#!/usr/bin/env python3
"""
fix_query_chains.py - Fix unsafe PHP query chain patterns in CodeIgniter files.

Transforms unsafe patterns like:
    $data['proposal'] = $this->db->query($sql)->row_array();
Into safe patterns like:
    $_q = $this->db->query($sql); $data['proposal'] = $_q ? $_q->row_array() : array();

Usage:
    python3 fix_query_chains.py --dry-run path/to/file.php
    python3 fix_query_chains.py --dry-run path/to/directory/
    python3 fix_query_chains.py --apply path/to/file.php
"""

import argparse
import os
import re
import sys


# Safe fallback for each supported chained method, emitted as PHP source text.
# The rewritten PHP uses this value when $this->db->query() returns a falsy
# value (see the `$_q ? $_q->method() : DEFAULT` form in the module docstring).
# In CodeIgniter, result() and result_array() both return an array (empty
# when there are no rows), so both fall back to array(): a null here would
# make a later `foreach` warn and `count()` throw a TypeError on PHP 8.
METHOD_DEFAULTS = {
    'row_array':    'array()',
    'result_array': 'array()',
    'row':          'null',
    'result':       'array()',
    'num_rows':     '0',
}

# Regexes used to recognise the unsafe query chain.
#
# Matching happens in two stages:
#   1. LINE_PATTERN / RETURN_PATTERN (below) match only the start of the
#      statement -- indentation, then either an assignment target followed by
#      `=` or the `return` keyword, then the literal `$this->db->query(` --
#      and capture everything after it on the line.
#   2. transform_line() passes that remainder to find_balanced_close_paren()
#      to locate the `)` that closes query(...), then checks that what follows
#      is exactly `->METHOD();` (optional whitespace, nothing after the
#      semicolon), where METHOD is row_array, result_array, row, result or
#      num_rows.
#
# A single regex cannot do stage 2 reliably. The query argument may contain
# nested parentheses, string literals with ')' inside them, concatenation,
# function calls, etc., so a small quote-aware scanner finds the closing
# paren instead.

# Assignment form: `$target = $this->db->query(...`. Only the start of the
# statement is matched. Group 3 receives the rest of the line after `query(`
# (the newline is excluded because `.` does not match it), and the closing
# paren plus the `->method();` tail are validated in code, not here.
# Targets such as `$rows[]`, `$data[$key]` or `$data['my-key']` do not match
# group 2 (bracket contents must be a bare \w+ word), so such lines are left
# untouched.
LINE_PATTERN = re.compile(
    r'^(\s*)'                                        # group 1: indentation
    r'('                                             # group 2: assignment target
        r'\$'                                        #   starts with $
        r'(?:'
            r'\w+'                                   #   variable name (also `this` in $this->prop)
            r'(?:'
                r'(?:\[[\'\"]?\w+[\'\"]?\])'         #   array access like ['key'], ["key"] or [0]
                r'|(?:->[\w]+)'                      #   object property like ->prop
            r')*'
        r')'
    r')'
    r'\s*=\s*'                                       # = with whitespace
    r'\$this->db->query\('                           # $this->db->query(
    r'(.+)'                                          # group 3: rest of the line after query(
)

# Return form: `return $this->db->query(...`. As with LINE_PATTERN, only the
# prefix is matched; group 2 is the rest of the line after `query(`.
RETURN_PATTERN = re.compile(
    r'^(\s*)'                                        # group 1: indentation
    r'return\s+'                                     # return keyword (at least one space)
    r'\$this->db->query\('                           # $this->db->query(
    r'(.+)'                                          # group 2: everything after query(
)

# Detectors for lines that are already guarded (e.g. output of this tool):
# SAFE_PATTERN matches a line that starts with `$_q = $this->db->query(`;
# SAFE_TERNARY matches `$_q ?` anywhere on the line.
SAFE_PATTERN = re.compile(r'^\s*\$_q\s*=\s*\$this->db->query\(')
SAFE_TERNARY = re.compile(r'\$_q\s*\?')


def find_balanced_close_paren(text):
    """Return the index of the ')' that closes an already-open '('.

    The caller has already consumed the opening parenthesis of ``query(``, so
    scanning starts at nesting depth one. Parentheses inside single- or
    double-quoted string literals are ignored, and within a literal a
    backslash escapes whatever character follows it.

    Returns -1 when no balancing ')' exists, including when a string literal
    is never terminated.
    """
    nesting = 1
    pos = 0
    end = len(text)
    while pos < end:
        char = text[pos]
        if char == '"' or char == "'":
            # Advance to the matching quote; a backslash consumes two chars.
            pos += 1
            while pos < end and text[pos] != char:
                pos += 2 if text[pos] == '\\' else 1
            # pos now sits on the closing quote (or past the end).
        elif char == '(':
            nesting += 1
        elif char == ')':
            nesting -= 1
            if not nesting:
                return pos
        pos += 1
    return -1


# Tail that must follow the ')' closing query(...): one supported result
# method called with empty parens, a semicolon, and only whitespace after it.
# Lines with anything else there (a trailing comment, further chaining,
# arguments to the method) are left unchanged.
_CHAIN_TAIL = re.compile(
    r'\s*->\s*(row_array|result_array|row|result|num_rows)\s*\(\s*\)\s*;\s*$'
)


def _rewrite_query_chain(indent, lhs, remainder, eol):
    """Build the guarded replacement for one query chain, or return None.

    Args:
        indent: Leading whitespace of the original line.
        lhs: Text placed before the ternary, e.g. ``"$data['x'] = "`` or
            ``"return "``.
        remainder: Text following ``$this->db->query(`` on the line.
        eol: Line terminator to append (``'\\n'``, ``'\\r\\n'`` or ``''``).

    Returns:
        The rewritten line, or None if the closing paren of ``query(`` cannot
        be found or the text after it is not a supported ``->method();`` tail.
    """
    close_idx = find_balanced_close_paren(remainder)
    if close_idx < 0:
        return None

    chain_match = _CHAIN_TAIL.match(remainder[close_idx + 1:])
    if not chain_match:
        return None

    query_arg = remainder[:close_idx]
    method = chain_match.group(1)
    default = METHOD_DEFAULTS[method]
    return (
        f"{indent}$_q = $this->db->query({query_arg}); "
        f"{lhs}$_q ? $_q->{method}() : {default};{eol}"
    )


def transform_line(line):
    """Rewrite one unsafe ``$this->db->query(...)->method();`` line.

    Two statement shapes are recognised::

        $target = $this->db->query(ARGS)->method();
        return $this->db->query(ARGS)->method();

    Each becomes a guarded form; the assignment form, for example, becomes::

        $_q = $this->db->query(ARGS); $target = $_q ? $_q->method() : DEFAULT;

    DEFAULT comes from METHOD_DEFAULTS. The line's original terminator
    (``\\n``, ``\\r\\n``, or none on a final line) is preserved.

    NOTE(review): the replacement puts two PHP statements on one line. If the
    original statement was the body of a brace-less ``if``/``else``/loop, only
    the first statement stays inside that control structure. Such call sites
    need a manual check.

    Returns:
        ``(new_line, was_changed)``. ``new_line`` is ``line`` itself when
        nothing was rewritten.
    """
    # Lines already in the guarded form (e.g. a previous run of this tool).
    if SAFE_PATTERN.search(line) or SAFE_TERNARY.search(line):
        return line, False

    # Cheap substring pre-filter before running any regex.
    if '$this->db->query(' not in line:
        return line, False

    # Split off the terminator so it can be re-attached unchanged.
    body = line.rstrip('\r\n')
    eol = line[len(body):]

    m = LINE_PATTERN.match(body)
    if m:
        indent, remainder = m.group(1), m.group(3)
        lhs = f"{m.group(2)} = "
    else:
        m = RETURN_PATTERN.match(body)
        if not m:
            return line, False
        indent, remainder = m.group(1), m.group(2)
        lhs = "return "

    new_line = _rewrite_query_chain(indent, lhs, remainder, eol)
    if new_line is None:
        return line, False
    return new_line, True


def _displayable(text):
    """Return *text* with undecodable input bytes shown as U+FFFD.

    Lines are decoded with the ``surrogateescape`` handler, which maps bytes
    that are not valid UTF-8 to lone surrogates. Printing those directly can
    raise UnicodeEncodeError, so they are replaced here, for display only.
    """
    return text.encode('utf-8', 'surrogateescape').decode('utf-8', 'replace')


def process_file(filepath, apply_changes=False):
    """Scan one PHP file, report each rewrite, and optionally save the result.

    Every changed line is printed as a -/+ pair. When ``apply_changes`` is
    true and at least one line changed, the file is rewritten in place.

    The file is read and written with ``newline=''`` and the
    ``surrogateescape`` error handler, so only the rewritten lines change on
    disk. Original line endings (LF or CRLF) are kept, and bytes that are not
    valid UTF-8 (e.g. in a Latin-1 encoded file) round-trip unchanged instead
    of being replaced with U+FFFD.

    Args:
        filepath: Path of the file to process.
        apply_changes: Write the rewritten content back when True; only
            report when False.

    Returns:
        Number of lines rewritten, or 0 if the file could not be read or
        written (an error is printed to stderr in that case).
    """
    try:
        with open(filepath, 'r', encoding='utf-8', errors='surrogateescape',
                  newline='') as f:
            lines = f.readlines()
    except OSError as e:
        print(f"  ERROR: Cannot read {filepath}: {e}", file=sys.stderr)
        return 0

    new_lines = []
    changes = 0

    for line_num, line in enumerate(lines, start=1):
        new_line, changed = transform_line(line)
        if changed:
            # Re-attach the original terminator ('\n', '\r\n', or none on the
            # final line) so a rewrite never alters the file's line endings.
            eol = line[len(line.rstrip('\r\n')):]
            new_line = new_line.rstrip('\r\n') + eol
            changes += 1
            print(f"  Line {line_num}:")
            print(f"    - {_displayable(line.rstrip())}")
            print(f"    + {_displayable(new_line.rstrip())}")
        new_lines.append(new_line)

    if changes > 0 and apply_changes:
        try:
            with open(filepath, 'w', encoding='utf-8', errors='surrogateescape',
                      newline='') as f:
                f.writelines(new_lines)
            print(f"  >> Written {changes} change(s) to {filepath}")
        except OSError as e:
            print(f"  ERROR: Cannot write {filepath}: {e}", file=sys.stderr)
            return 0

    return changes


def collect_php_files(path):
    """Collect the PHP files to process from a file or directory path.

    A regular file is returned as-is, with a warning when it lacks a ``.php``
    extension. A directory is walked recursively and every ``*.php`` file
    below it is returned. Subdirectories and file names are both visited in
    sorted order, so the result (and therefore the tool's report) does not
    depend on the order in which the filesystem lists entries.

    Returns:
        List of file paths. The list is empty, after an error is printed to
        stderr, when ``path`` is neither a file nor a directory.
    """
    if os.path.isfile(path):
        if not path.endswith('.php'):
            print(f"  WARNING: {path} is not a .php file, processing anyway.")
        return [path]

    if os.path.isdir(path):
        php_files = []
        for root, dirs, files in os.walk(path):
            # Sorting dirs in place controls the order in which top-down
            # os.walk descends into subdirectories.
            dirs.sort()
            php_files.extend(
                os.path.join(root, fname)
                for fname in sorted(files)
                if fname.endswith('.php')
            )
        return php_files

    print(f"  ERROR: {path} not found.", file=sys.stderr)
    return []


def main():
    """Command-line entry point.

    Parses ``--dry-run`` / ``--apply`` (exactly one is required) plus one or
    more file or directory paths. Each collected PHP file is run through
    process_file(), and a summary of scanned files and replacements is
    printed at the end.
    """
    parser = argparse.ArgumentParser(
        description='Fix unsafe PHP query chain patterns ($this->db->query()->method()).'
    )
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument('--dry-run', action='store_true',
                       help='Show changes without writing files.')
    group.add_argument('--apply', action='store_true',
                       help='Apply changes to files in-place.')
    parser.add_argument('paths', nargs='+',
                        help='PHP file(s) or directory(ies) to process.')
    opts = parser.parse_args()

    scanned = 0
    touched = 0
    replacements = 0

    for target in opts.paths:
        for php_path in collect_php_files(target):
            scanned += 1
            print(f"\n--- {php_path} ---")
            found = process_file(php_path, apply_changes=opts.apply)
            if found <= 0:
                print("  (no unsafe patterns found)")
                continue
            touched += 1
            replacements += found

    print(f"\n{'=' * 60}")
    print(f"Summary: {replacements} replacement(s) across {touched} file(s) "
          f"({scanned} file(s) scanned)")
    if replacements > 0:
        if opts.dry_run:
            print("Mode: DRY RUN (no files were modified)")
        elif opts.apply:
            print("Mode: APPLY (files were modified in-place)")


# Run the command-line interface only when executed as a script, not on import.
if __name__ == '__main__':
    main()
