From 14dfbaa6dab0f28685fe3794691891f2c96b28d2 Mon Sep 17 00:00:00 2001 From: raet Date: Tue, 20 Jan 2026 20:06:28 -0800 Subject: [PATCH] init --- README.md | 91 +++ auto_classes.json | 4 + config.json | 24 + program.json | 44 ++ push.py | 52 ++ pyproject.toml | 7 + regspy/__init__.py | 10 + regspy/cli.py | 1340 ++++++++++++++++++++++++++++++++++++++++++++ 8 files changed, 1572 insertions(+) create mode 100644 auto_classes.json create mode 100644 config.json create mode 100644 program.json create mode 100644 push.py create mode 100644 pyproject.toml create mode 100644 regspy/__init__.py create mode 100644 regspy/cli.py diff --git a/README.md b/README.md index 571e2d2..35bfb10 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,93 @@ + # regspy +regspy is a regex pattern generator, you enter some data -> select what you want matched and or not matched -> ??? -> Pattern! + +![alt text](imgs/demo.gif) + +This project started as me trying to learn dspy, its vibe coded to shit and back but it works and has some accomplishments: + - Runs on small models with 3B parameter at a minimum, so it should run on anything. + - It outperforms grex ~~in metrics that were defined by me~~. + - Learns from what you feed it, it generated a pattern you liked? add it to the training set! + - No human written prompts or rules or "make sure to NOT explode" bs. + - Context aware generation, it learns from failed patterns and most importantly WHY it failed. + - Generates patterns based on a scoring system that ranks patterns by: + - **matches_all**: Percentage of required items the pattern matches + - **excludes_all**: Percentage of excluded items the pattern avoids + - *If no excluded items are selected, this metrics weights are divided equally amongst the others.* + - **coherence**: How similar extra matches are to target items + - **generalization**: Use of character classes (\\d, \\w) vs literals + - **simplicity**: How short patterns are and without the use of branching + +Is it perfect? 
hell no, the training set, scoring system, hint generation could be improved upon, so if you want have a go at it i included a CLAUDE.md for you. + +But if you're a everyday smooth brain like me that needs a simple pattern on the fly because for some reason your brain is physically impossible of remembering that lookaheads exist, regspy should be of some help. + +## Features + +- **Visual Text Selection**: Highlight text to create match examples (cyan) or exclusions (red) +- **LLM-Powered Generation**: Uses local Ollama with qwen2.5-coder:3b for intelligent pattern creation +- **Training Dataset**: 227+ curated examples with ability to add your own +- **Pre-compilation**: Optional rule extraction for faster runtime inference +- **Session Config**: Adjust model, temperature, and scoring weights on the fly + +## Installation + +- **AutoHotkey v2.0** - [Download](https://www.autohotkey.com/) +- **Python Libs**: + ```bash + pip install dspy grex ollama + ``` +- **Ollama**: + ```bash + ollama serve + ollama pull qwen2.5-coder:3b + ``` +- **Run**: + ```bash + AutoHotkey64.exe regspy.ahk # Or just double click regspy.ahk + ``` + +### CLI flags + +```bash +# Run test suite +python regexgen.py --test + +# Pre-compile for faster runtime +python regexgen.py --compile + +# Generate regex from JSON input +python regexgen.py input.json output.json + +# With custom config +python regexgen.py input.json output.json --config config.json + +# Dataset management +python regexgen.py --list-dataset output.json +python regexgen.py --add-example example.json +python regexgen.py --delete-example +``` + +## Architecture + +``` +┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐ +│ AutoHotkey │────▶│ Web Frontend │────▶│ Python │ +│ (Host) │◀────│ (WebView2) │◀────│ (DSPy/LLM) │ +└─────────────────┘ └─────────────────┘ └─────────────────┘ + Window Text Selection Regex Generation + Management Highlighting Multi-criteria + IPC Bridge Results Display Scoring +``` + +## Configuration + 
+The Config tab allows session-level adjustments: + +- **Model**: Ollama model name (default: `qwen2.5-coder:3b`) +- **Temperature**: LLM creativity (default: 0.4) +- **Max Attempts**: Refinement iterations (default: 10) +- **Reward Threshold**: Stop early if score exceeds (default: 0.85) +- **Scoring Weights**: Adjust the 5 criteria weights +- **Context Window** (`num_ctx`): Ollama context size (default: 8192). Ollama defaults to 4096 which can truncate prompts with many training examples. If you see "truncating input prompt" warnings in Ollama logs, bump this up. Uses ~200MB extra VRAM per 4K increase on 3B models. \ No newline at end of file diff --git a/auto_classes.json b/auto_classes.json new file mode 100644 index 0000000..9fe8e21 --- /dev/null +++ b/auto_classes.json @@ -0,0 +1,4 @@ +{ + "AutoConfig": "regspy.cli.RegexConfig", + "AutoProgram": "regspy.cli.RegexProgram" +} \ No newline at end of file diff --git a/config.json b/config.json new file mode 100644 index 0000000..2b5e7d2 --- /dev/null +++ b/config.json @@ -0,0 +1,24 @@ +{ + "model": "qwen2.5-coder:3b", + "ollama_url": "http://localhost:11434", + "temperature": 0.4, + "num_ctx": 8192, + "enable_cache": false, + "max_attempts": 10, + "reward_threshold": 0.85, + "fail_count": null, + "use_cot": true, + "dataset_file": "/Users/fadel/Desktop/dev/regspy/dspy/regex-dspy-train.json", + "compiled_program_path": "/Users/fadel/Desktop/dev/regspy/dspy/regex_compiled.json", + "compile_threads": 8, + "compile_candidates": 16, + "compile_num_rules": 5, + "debug": true, + "weights": { + "matches_all": 0.35, + "excludes_all": 0.25, + "coherence": 0.15, + "generalization": 0.15, + "simplicity": 0.1 + } +} \ No newline at end of file diff --git a/program.json b/program.json new file mode 100644 index 0000000..c0063b5 --- /dev/null +++ b/program.json @@ -0,0 +1,44 @@ +{ + "program.predict": { + "traces": [], + "train": [], + "demos": [], + "signature": { + "instructions": "Generate a regex pattern from examples.", + 
"fields": [ + { + "prefix": "Text:", + "description": "The full text to search within" + }, + { + "prefix": "Match Items:", + "description": "Strings the pattern MUST match" + }, + { + "prefix": "Exclude Items:", + "description": "Strings the pattern must NOT match" + }, + { + "prefix": "Pattern Hints:", + "description": "Analysis hints about the match items" + }, + { + "prefix": "Reasoning: Let's think step by step in order to", + "description": "${reasoning}" + }, + { + "prefix": "Pattern:", + "description": "Regex pattern" + } + ] + }, + "lm": null + }, + "metadata": { + "dependency_versions": { + "python": "3.13", + "dspy": "3.1.2", + "cloudpickle": "3.1" + } + } +} \ No newline at end of file diff --git a/push.py b/push.py new file mode 100644 index 0000000..3d3cedb --- /dev/null +++ b/push.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 +""" +Push RegSpy program to Modaic Hub. + +Usage: + uv run push.py your-username/regspy + uv run push.py your-username/regspy --with-code --commit-message "Initial" +""" + +from __future__ import annotations + +import argparse + +from regspy import RegexConfig, RegexProgram + + +def create_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description="Push RegSpy PrecompiledProgram to Modaic Hub", + ) + parser.add_argument("repo", help="Hub repo in the form username/name") + parser.add_argument("--with-code", action="store_true", help="Bundle code with the push") + parser.add_argument("--commit-message", help="Optional commit message") + parser.add_argument("--branch", help="Optional branch name") + parser.add_argument("--tag", help="Optional tag name") + parser.add_argument("--private", action="store_true", help="Push to a private repo") + return parser + + +def main() -> None: + parser = create_parser() + args = parser.parse_args() + + program = RegexProgram(RegexConfig()) + + push_kwargs: dict[str, object] = { + "with_code": args.with_code, + "private": args.private, + } + if args.commit_message: + 
push_kwargs["commit_message"] = args.commit_message + if args.branch: + push_kwargs["branch"] = args.branch + if args.tag: + push_kwargs["tag"] = args.tag + + program.push_to_hub(args.repo, **push_kwargs) + print(f"Pushed to {args.repo}") + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..c2e2963 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,7 @@ +[project] +name = "regspy" +version = "0.1.0" +description = "Add your description here" +readme = "README.md" +requires-python = ">=3.13" +dependencies = ["dspy>=3.1.2", "grex>=1.0.2", "modaic>=0.10.3"] diff --git a/regspy/__init__.py b/regspy/__init__.py new file mode 100644 index 0000000..0b5e813 --- /dev/null +++ b/regspy/__init__.py @@ -0,0 +1,10 @@ +"""RegSpy package exports.""" + +from .cli import ( + RegexConfig, + RegexProgram, + compile_and_save, + generate_regex, +) + +__all__ = ["RegexConfig", "RegexProgram", "compile_and_save", "generate_regex"] diff --git a/regspy/cli.py b/regspy/cli.py new file mode 100644 index 0000000..b9b2b18 --- /dev/null +++ b/regspy/cli.py @@ -0,0 +1,1340 @@ +""" +Regex Generator - DSPy + Refine +================================ + +Generates regex patterns from example match/exclude strings using DSPy with +Ollama LLM. Uses a 5-weight reward system (matches_all, excludes_all, +coherence, generalization, simplicity) to iteratively refine patterns. + +Supports pre-compilation for faster runtime inference. 
+ +Prerequisites: + pip install dspy grex modaic + ollama serve + +Usage: + python -m regspy --test # Run test cases + python -m regspy --compile # Pre-compile for faster runtime + python -m regspy # Generate regex + python -m regspy --config + python -m regspy --list-dataset # Export training dataset + python -m regspy --add-example # Add example to dataset + python -m regspy --delete-example # Delete example from dataset +""" + +import argparse +import json +import os +import re +import sys +from dataclasses import dataclass, field +from pathlib import Path +import dspy +from dspy.teleprompt import InferRules, LabeledFewShot +from grex import RegExpBuilder +from modaic import PrecompiledConfig, PrecompiledProgram + +# ============================================================================= +# Constants +# ============================================================================= + +SCRIPT_DIR = Path(__file__).parent.resolve() +PROJECT_ROOT = SCRIPT_DIR.parent + + +def resolve_path(relative_path: str) -> str: + """Resolve a relative path to an absolute path based on project root.""" + return str(PROJECT_ROOT / relative_path) + + +# ============================================================================= +# Configuration +# ============================================================================= + + +class RegexConfig(PrecompiledConfig): + """Configuration for the regex generator.""" + + # Model settings + model: str = "qwen2.5-coder:3b" + ollama_url: str = "http://localhost:11434" + temperature: float = 0.4 + num_ctx: int = 8192 + enable_cache: bool = False + + # Refine parameters + max_attempts: int = 10 + reward_threshold: float = 0.85 + fail_count: int | None = None + + # Few-shot settings + use_cot: bool = True + + # File paths + dataset_file: str = field(default_factory=lambda: resolve_path("dspy/regex-dspy-train.json")) + compiled_program_path: str = field(default_factory=lambda: resolve_path("dspy/regex_compiled.json")) + + # Compilation 
settings (InferRules optimizer) + compile_threads: int = 8 + compile_candidates: int = 16 + compile_num_rules: int = 5 + + # Debug output + debug: bool = True + + # Scoring weights + weights: dict = field(default_factory=lambda: { + "matches_all": 0.35, + "excludes_all": 0.25, + "coherence": 0.15, + "generalization": 0.15, + "simplicity": 0.10, + }) + + +# Alias for backward compatibility +Config = RegexConfig + + +# ============================================================================= +# DSPy Signature and Module +# ============================================================================= + + +class GenerateRegex(dspy.Signature): + """Generate a regex pattern from examples.""" + + text: str = dspy.InputField(desc="The full text to search within") + match_items: list[str] = dspy.InputField(desc="Strings the pattern MUST match") + exclude_items: list[str] = dspy.InputField(desc="Strings the pattern must NOT match") + pattern_hints: str = dspy.InputField(desc="Analysis hints about the match items") + pattern: str = dspy.OutputField(desc="Regex pattern") + + +class RegexProgram(PrecompiledProgram): + """Modaic-compatible DSPy program for regex generation.""" + + config: RegexConfig + + def __init__(self, config: RegexConfig, **kwargs): + super().__init__(config, **kwargs) + + if config.use_cot: + self.program = dspy.ChainOfThought(GenerateRegex) + else: + self.program = dspy.Predict(GenerateRegex) + + def forward(self, **kwargs): + return self.program(**kwargs) + + +def build_base_program(config: RegexConfig) -> RegexProgram: + """Create the base DSPy program used for compilation and inference.""" + return RegexProgram(config) + + +# ============================================================================= +# String Utilities +# ============================================================================= + + +def common_prefix(strings: list[str]) -> str: + """Find common prefix of all strings.""" + if not strings: + return "" + prefix = strings[0] + 
for s in strings[1:]: + while not s.startswith(prefix): + prefix = prefix[:-1] + if not prefix: + return "" + return prefix + + +def common_suffix(strings: list[str]) -> str: + """Find common suffix of all strings.""" + if not strings: + return "" + suffix = strings[0] + for s in strings[1:]: + while not s.endswith(suffix): + suffix = suffix[1:] + if not suffix: + return "" + return suffix + + +def safe_unescape(pattern: str) -> str: + """ + Safely unescape double-escaped regex metacharacters from LLM output. + Only unescapes known regex sequences to avoid breaking valid patterns. + """ + if not pattern or "\\\\" not in pattern: + return pattern + + metachar_escapes = ["d", "D", "w", "W", "s", "S", "b", "B", "n", "t", "r", "A", "Z"] + result = pattern + for char in metachar_escapes: + result = result.replace(f"\\\\{char}", f"\\{char}") + return result + + +# ============================================================================= +# Pattern Hint Analysis +# ============================================================================= + + +def _detect_case_patterns(match_items: list[str]) -> tuple[list[str], list[str]]: + """Detect uppercase/lowercase patterns in match items.""" + observations = [] + suggestions = [] + + if all(item.isalpha() and item.isupper() for item in match_items): + observations.append("All items are UPPERCASE") + suggestions.append("[A-Z]+") + elif all(item.isalpha() and item.islower() for item in match_items): + observations.append("All items are lowercase") + suggestions.append("[a-z]+") + + return observations, suggestions + + +def _detect_numeric_patterns(match_items: list[str]) -> tuple[list[str], list[str]]: + """Detect numeric/alphanumeric patterns in match items.""" + observations = [] + suggestions = [] + + if all(item.isdigit() for item in match_items): + observations.append("All items are numeric (digits only)") + suggestions.append(r"\d+") + elif all(item.isalnum() for item in match_items): + observations.append("All items 
are alphanumeric") + suggestions.append(r"\w+") + + if all(any(c.isdigit() for c in item) for item in match_items): + observations.append("All items contain digits") + suggestions.append(r"\d+") + + return observations, suggestions + + +def _detect_structure_patterns(match_items: list[str]) -> tuple[list[str], list[str]]: + """Detect structural patterns like word_digits, word-digits, etc.""" + structure_patterns = [ + (r"^[a-zA-Z]+_\d+$", r"\w+_\d+", "word_digits with underscore"), + (r"^[a-zA-Z]+-\d+$", r"\w+-\d+", "word-digits with hyphen"), + (r"^[a-zA-Z]+\.\d+$", r"\w+\.\d+", "word.digits with dot"), + (r"^[a-zA-Z]+:\d+$", r"\w+:\d+", "word:digits with colon"), + (r"^\d+_[a-zA-Z]+$", r"\d+_\w+", "digits_word with underscore"), + (r"^\d+-[a-zA-Z]+$", r"\d+-\w+", "digits-word with hyphen"), + (r"^[A-Z]\d+$", r"[A-Z]\d+", "single letter followed by digits"), + (r"^[A-Z]{1,3}\d+$", r"[A-Z]{1,3}\d+", "short prefix followed by digits"), + (r"^[a-z]+@[a-z]+\.[a-z]+$", r"[a-z]+@[a-z]+\.[a-z]+", "email-like pattern"), + ] + + for pattern, suggestion, description in structure_patterns: + if all(re.match(pattern, item, re.IGNORECASE) for item in match_items): + return [f"All items follow '{description}' structure"], [suggestion] + + return [], [] + + +def _detect_wrapper_patterns(match_items: list[str]) -> tuple[list[str], list[str]]: + """Detect wrapper patterns (e.g., <|im_start|>, {{foo}}, [bar]).""" + if len(match_items) <= 1: + return [], [] + + prefix = common_prefix(match_items) + suffix = common_suffix(match_items) + special_chars = set(r"[](){}|^$.*+?\/<>") + + has_special_prefix = prefix and any(c in special_chars for c in prefix) + has_special_suffix = suffix and any(c in special_chars for c in suffix) + + if has_special_prefix and has_special_suffix: + escaped_prefix = re.escape(prefix) + escaped_suffix = re.escape(suffix) + wrapper_pattern = f"{escaped_prefix}.+{escaped_suffix}" + return [f"Items wrapped in '{prefix}...{suffix}'"], [wrapper_pattern] + + 
return [], [] + + +def _detect_prefix_suffix(match_items: list[str]) -> tuple[list[str], list[str]]: + """Detect common prefix and suffix patterns.""" + if len(match_items) <= 1: + return [], [] + + observations = [] + suggestions = [] + + prefix = common_prefix(match_items) + if len(prefix) >= 1 and not prefix.isspace(): + observations.append(f"Common prefix: '{prefix}'") + if len(prefix) <= 3: + suggestions.append(re.escape(prefix)) + + suffix = common_suffix(match_items) + if len(suffix) >= 1 and not suffix.isspace(): + observations.append(f"Common suffix: '{suffix}'") + + return observations, suggestions + + +def _detect_exclusion_hints( + match_items: list[str], exclude_items: list[str] +) -> tuple[list[str], list[str], list[str]]: + """Detect hints based on what to exclude.""" + if not exclude_items: + return [], [], [] + + observations = [] + suggestions = [] + avoid = [] + + match_lengths = set(len(item) for item in match_items) + exclude_lengths = set(len(item) for item in exclude_items) + + if match_lengths.isdisjoint(exclude_lengths): + observations.append("Match items have different lengths than excludes") + + # Check for substring relationships + for exc in exclude_items: + for match in match_items: + if match in exc and match != exc: + observations.append( + f"'{match}' is substring of excluded '{exc}' - use \\b word boundaries" + ) + suggestions.append(r"\b") + avoid.append("Lookahead (?!) - use word boundaries \\b instead") + avoid.append("Patterns without word boundaries that match substrings") + break + + # Check length relationships + if match_lengths and exclude_lengths and max(match_lengths) < min(exclude_lengths): + observations.append( + "All matches are shorter than excludes - length-limited quantifiers may help" + ) + + return observations, suggestions, avoid + + +def analyze_match_items( + match_items: list[str], exclude_items: list[str] | None = None +) -> dict: + """ + Analyze match items to generate pattern hints for the LLM. 
+ Returns a dict with observations, suggested_fragments, and avoid lists. + """ + exclude_items = exclude_items or [] + hints: dict[str, list[str]] = { + "observations": [], + "suggested_fragments": [], + "avoid": [], + } + + if not match_items: + return hints + + # Collect hints from various detectors + detectors = [ + lambda: _detect_case_patterns(match_items), + lambda: _detect_numeric_patterns(match_items), + lambda: _detect_structure_patterns(match_items), + lambda: _detect_wrapper_patterns(match_items), + lambda: _detect_prefix_suffix(match_items), + ] + + for detector in detectors: + obs, sugg = detector() + hints["observations"].extend(obs) + # Structure patterns should be high priority + if detector == detectors[2] and sugg: # structure detector + hints["suggested_fragments"] = sugg + hints["suggested_fragments"] + else: + hints["suggested_fragments"].extend(sugg) + + # Detect separators + separators = set() + for item in match_items: + for sep in ["-", "_", ".", ":", "/"]: + if sep in item: + separators.add(sep) + if separators: + hints["observations"].append(f"Contains separators: {separators}") + + # Check consistent length + lengths = [len(item) for item in match_items] + if len(set(lengths)) == 1: + hints["observations"].append(f"All items have length {lengths[0]}") + + # Exclusion hints + obs, sugg, avoid = _detect_exclusion_hints(match_items, exclude_items) + hints["observations"].extend(obs) + hints["suggested_fragments"].extend(sugg) + hints["avoid"].extend(avoid) + + return hints + + +def format_hints_for_prompt(hints: dict) -> str: + """Format hints dict into a string for the LLM prompt.""" + parts = [] + + if hints["observations"]: + parts.append("Observations: " + "; ".join(hints["observations"])) + if hints["suggested_fragments"]: + parts.append("Suggested fragments: " + ", ".join(hints["suggested_fragments"])) + if hints["avoid"]: + parts.append("Avoid: " + "; ".join(hints["avoid"])) + + return " | ".join(parts) if parts else "No specific 
patterns detected" + + +# ============================================================================= +# Pattern Analysis and Scoring +# ============================================================================= + + +class PatternAnalysis: + """Analyze a regex pattern for quality metrics.""" + + def __init__( + self, + pattern: str, + text: str, + match_items: list[str], + exclude_items: list[str], + ): + self.raw_pattern = pattern + self.pattern = safe_unescape(pattern) + self.text = text + self.match_items = match_items + self.exclude_items = exclude_items + self.matches: set[str] = set() + self.is_valid = False + self._analyze() + + def _analyze(self) -> None: + """Run the pattern and collect matches.""" + if not self.pattern: + return + try: + self.matches = set(re.findall(self.pattern, self.text)) + self.is_valid = True + except re.error: + self.is_valid = False + + def matches_all_required(self) -> float: + """Score: Does it match all required items? (0.0 to 1.0)""" + if not self.is_valid or not self.match_items: + return 0.0 + matched = sum(1 for item in self.match_items if item in self.matches) + return matched / len(self.match_items) + + def excludes_all_forbidden(self) -> float: + """Score: Does it avoid all excluded items? (0.0 to 1.0)""" + if not self.is_valid: + return 0.0 + if not self.exclude_items: + return 1.0 + excluded_matches = self.matches & set(self.exclude_items) + if not excluded_matches: + return 1.0 + return 1.0 - (len(excluded_matches) / len(self.exclude_items)) + + def uses_character_classes(self) -> float: + """Score: Does it use generic char classes? 
(0.0 to 1.0)""" + if not self.pattern: + return 0.0 + + score = 0.0 + + # Check for metacharacters + metachar_patterns = [r"\\d", r"\\w", r"\\s", r"\\b", r"\\D", r"\\W", r"\\S"] + for mp in metachar_patterns: + if re.search(mp, self.pattern): + score += 0.2 + + # Check for bracket character classes + bracket_classes = [ + r"\[A-Z\]", + r"\[a-z\]", + r"\[0-9\]", + r"\[A-Za-z\]", + r"\[\w-\]+", + ] + for bc in bracket_classes: + if re.search(bc, self.pattern, re.IGNORECASE): + score += 0.25 + + # Bonus for quantifiers on classes + if re.search(r"(?:\\[dwsb]|\[[^\]]+\])[+*?]|\{[\d,]+\}", self.pattern): + score += 0.15 + + return min(1.0, score) + + def simplicity_score(self) -> float: + """Score: Simpler patterns are better (0.0 to 1.0)""" + if not self.pattern: + return 0.0 + + # Length score + length = len(self.pattern) + if length < 5: + length_score = 0.3 + elif length <= 20: + length_score = 1.0 + elif length <= 40: + length_score = 0.8 + elif length <= 60: + length_score = 0.6 + elif length <= 100: + length_score = 0.4 + else: + length_score = 0.2 + + # Complexity score (cyclomatic-style) + alternations = self.pattern.count("|") + groups = self.pattern.count("(") + complexity = 1 + alternations + groups + + if complexity <= 2: + complexity_score = 1.0 + elif complexity <= 4: + complexity_score = 0.8 + elif complexity <= 7: + complexity_score = 0.6 + else: + complexity_score = 0.4 + + return 0.6 * length_score + 0.4 * complexity_score + + def coherence_score(self) -> float: + """Score: Are extra matches similar to intended match_items? 
(0.0 to 1.0)""" + if not self.matches or not self.match_items: + return 0.0 + + wanted = set(self.match_items) + excluded = set(self.exclude_items) if self.exclude_items else set() + extra = self.matches - wanted - excluded + + if not extra: + return 1.0 + + def get_char_profile(s: str) -> tuple[bool, bool, bool, bool]: + return ( + any(c.isupper() for c in s), + any(c.islower() for c in s), + any(c.isdigit() for c in s), + any(not c.isalnum() for c in s), + ) + + def get_bigrams(s: str) -> set[str]: + if len(s) < 2: + return {s} if s else set() + return set(s[i : i + 2] for i in range(len(s) - 1)) + + def pairwise_similarity(a: str, b: str) -> float: + if not a or not b: + return 0.0 + + max_len = max(len(a), len(b)) + len_sim = 1.0 - abs(len(a) - len(b)) / max_len + + profile_a = get_char_profile(a) + profile_b = get_char_profile(b) + profile_sim = sum(pa == pb for pa, pb in zip(profile_a, profile_b)) / 4.0 + + bigrams_a = get_bigrams(a) + bigrams_b = get_bigrams(b) + if bigrams_a and bigrams_b: + jaccard = len(bigrams_a & bigrams_b) / len(bigrams_a | bigrams_b) + else: + jaccard = 1.0 if a == b else 0.0 + + return 0.35 * len_sim + 0.40 * profile_sim + 0.25 * jaccard + + coherence_scores = [] + for item in extra: + max_sim = max(pairwise_similarity(item, ref) for ref in wanted) + coherence_scores.append(max_sim) + + return sum(coherence_scores) / len(coherence_scores) + + +def compute_scores(analysis: PatternAnalysis) -> dict[str, float]: + """Compute all individual scores from a pattern analysis.""" + return { + "matches_all": analysis.matches_all_required(), + "excludes_all": analysis.excludes_all_forbidden(), + "coherence": analysis.coherence_score(), + "generalization": analysis.uses_character_classes(), + "simplicity": analysis.simplicity_score(), + } + + +def compute_total_score( + scores: dict[str, float], weights: dict[str, float], has_excludes: bool +) -> float: + """Compute weighted total score, redistributing weights if no excludes.""" + if not 
has_excludes: + active_weights = {k: v for k, v in weights.items() if k != "excludes_all"} + weight_sum = sum(active_weights.values()) + return sum(scores[k] * active_weights[k] / weight_sum for k in active_weights) + return sum(scores[k] * weights[k] for k in scores) + + +def compute_reward(args: dict, pred: dspy.Prediction, config: Config) -> float: + """Multi-criteria reward function for dspy.Refine.""" + pattern = getattr(pred, "pattern", "") + exclude_items = args.get("exclude_items", []) + + analysis = PatternAnalysis( + pattern=pattern, + text=args.get("text", ""), + match_items=args.get("match_items", []), + exclude_items=exclude_items, + ) + + if not analysis.is_valid: + if config.debug: + print(f" Pattern: {pattern}") + print(" Invalid syntax - score: 0.0") + return 0.0 + + scores = compute_scores(analysis) + total_score = compute_total_score(scores, config.weights, bool(exclude_items)) + + if config.debug: + print(f" Pattern: {pattern}") + print(f" Scores: {', '.join(f'{k}={v:.2f}' for k, v in scores.items())}") + print(f" Total: {total_score:.3f}") + + return total_score + + +# ============================================================================= +# Pattern Result +# ============================================================================= + + +@dataclass +class PatternResult: + """Holds detailed scoring info for a single pattern.""" + + pattern: str + source: str # 'llm' or 'grex' + is_valid: bool + total_score: float + scores: dict[str, float] + all_matches: list[str] + matched_items: list[str] + missed_items: list[str] + excluded_matched: list[str] + extra_matches: list[str] + + def to_dict(self) -> dict: + """Convert to JSON-serializable dict.""" + return { + "pattern": self.pattern, + "source": self.source, + "is_valid": self.is_valid, + "total_score": round(self.total_score, 4), + "scores": {k: round(v, 4) for k, v in self.scores.items()}, + "all_matches": self.all_matches, + "matched_items": self.matched_items, + "missed_items": 
self.missed_items, + "excluded_matched": self.excluded_matched, + "extra_matches": self.extra_matches, + } + + +def score_pattern( + pattern: str, + text: str, + match_items: list[str], + exclude_items: list[str], + config: Config, + source: str = "llm", +) -> PatternResult: + """Score a pattern and return detailed results.""" + exclude_items = exclude_items or [] + analysis = PatternAnalysis(pattern, text, match_items, exclude_items) + + if not analysis.is_valid: + return PatternResult( + pattern=pattern, + source=source, + is_valid=False, + total_score=0.0, + scores={k: 0.0 for k in config.weights}, + all_matches=[], + matched_items=[], + missed_items=list(match_items), + excluded_matched=[], + extra_matches=[], + ) + + scores = compute_scores(analysis) + total_score = compute_total_score(scores, config.weights, bool(exclude_items)) + + wanted = set(match_items) + excluded = set(exclude_items) + + return PatternResult( + pattern=pattern, + source=source, + is_valid=True, + total_score=total_score, + scores=scores, + all_matches=list(analysis.matches), + matched_items=[item for item in match_items if item in analysis.matches], + missed_items=[item for item in match_items if item not in analysis.matches], + excluded_matched=[item for item in exclude_items if item in analysis.matches], + extra_matches=[m for m in analysis.matches if m not in wanted and m not in excluded], + ) + + +# ============================================================================= +# Grex Pattern Generation +# ============================================================================= + + +def generate_grex_pattern(match_items: list[str]) -> str | None: + """Generate a baseline pattern using grex.""" + try: + return ( + RegExpBuilder.from_test_cases(match_items) + .with_conversion_of_digits() + .with_conversion_of_words() + .with_conversion_of_repetitions() + .without_anchors() + .build() + ) + except Exception: + return None + + +# 
============================================================================= +# Training Data +# ============================================================================= + + +def load_trainset(config: Config) -> list[dspy.Example]: + """Load training examples from file.""" + trainset = [] + + try: + with open(config.dataset_file, "r") as f: + data = json.load(f) + + for item in data: + hints = analyze_match_items( + item["match_items"], item.get("exclude_items", []) + ) + example = dspy.Example( + text=item["text"], + match_items=item["match_items"], + exclude_items=item.get("exclude_items", []), + pattern_hints=format_hints_for_prompt(hints), + pattern=item["expected_pattern"], + ).with_inputs("text", "match_items", "exclude_items", "pattern_hints") + trainset.append(example) + except FileNotFoundError: + if config.debug: + print(f"[WARN] Dataset file not found: {config.dataset_file}") + + return trainset + + +# ============================================================================= +# Compilation +# ============================================================================= + + +def compile_and_save(config: RegexConfig | None = None) -> str | None: + """ + Pre-compile the regex generator using InferRules. + Returns the path to the saved compiled program. 
+ """ + config = config or RegexConfig() + + print("[COMPILE] Setting up DSPy...") + lm = dspy.LM( + f"ollama_chat/{config.model}", + api_base=config.ollama_url, + api_key="", + cache=True, + temperature=config.temperature, + num_ctx=config.num_ctx, + ) + dspy.configure(lm=lm) + + print("[COMPILE] Loading training data...") + trainset = load_trainset(config) + + if not trainset: + print(f"[ERROR] No training data found: {config.dataset_file}") + return None + + print(f"[COMPILE] Loaded {len(trainset)} training examples") + + # Build base module + base_module = build_base_program(config) + + def metric_fn(example, pred, trace=None): # noqa: ARG001 + args = { + "text": example.text, + "match_items": example.match_items, + "exclude_items": example.exclude_items, + } + return compute_reward(args, pred, config) + + print("[COMPILE] Running InferRules...") + print(f" Candidates: {config.compile_candidates}") + print(f" Rules to extract: {config.compile_num_rules}") + print(f" Threads: {config.compile_threads}") + + optimizer = InferRules( + metric=metric_fn, + num_candidates=config.compile_candidates, + num_rules=config.compile_num_rules, + num_threads=config.compile_threads, + ) + + compiled_program = optimizer.compile(student=base_module, trainset=trainset) + + save_path = config.compiled_program_path + compiled_program.save(save_path) + print(f"[COMPILE] Saved compiled program to: {save_path}") + + return save_path + + +def load_compiled_program(config: Config): + """Load a pre-compiled program if it exists.""" + if not os.path.exists(config.compiled_program_path): + return None + + program = build_base_program(config) + + try: + program.load(config.compiled_program_path) + if config.debug: + print(f"[LOADED] Pre-compiled program from {config.compiled_program_path}") + return program + except Exception as e: + if config.debug: + print(f"[WARN] Failed to load compiled program: {e}") + return None + + +# 
=============================================================================
# Main Generator
# =============================================================================


def generate_regex(input_data: dict, config: RegexConfig | None = None) -> dict:
    """
    Main entry point for regex generation using dspy.Refine.

    Returns a dict with:
    - results: list of PatternResult dicts (sorted by score, best first)
    - hints_used: pattern hints string
    """
    config = config or RegexConfig()

    # Input keys mirror the frontend payload ("Highlighted Items" / "Excluded Items").
    text = input_data.get("text", "")
    match_items = input_data.get("Highlighted Items", [])
    exclude_items = input_data.get("Excluded Items", [])

    if not text or not match_items:
        raise ValueError("Input must contain 'text' and 'Highlighted Items'")

    # Structural hints derived from the selections are injected into the prompt.
    hints = analyze_match_items(match_items, exclude_items)
    hints_str = format_hints_for_prompt(hints)

    if config.debug:
        print(f"\n[PATTERN HINTS] {hints_str}")

    pattern_results: list[PatternResult] = []

    # Generate grex baseline (non-LLM candidate scored by the same rubric).
    grex_pattern = generate_grex_pattern(match_items)
    if grex_pattern:
        grex_result = score_pattern(
            grex_pattern, text, match_items, exclude_items, config, source="grex"
        )
        pattern_results.append(grex_result)
        if config.debug:
            print(f"\n[GREX] {grex_pattern} (score: {grex_result.total_score:.3f})")

    # Setup DSPy
    lm = dspy.LM(
        f"ollama_chat/{config.model}",
        api_base=config.ollama_url,
        api_key="",
        cache=config.enable_cache,
        temperature=config.temperature,
        num_ctx=config.num_ctx,
    )
    dspy.configure(lm=lm)

    # Load or build module: prefer the pre-compiled program when one exists.
    compiled_module = load_compiled_program(config)

    if compiled_module is None:
        trainset = load_trainset(config)

        base_module = build_base_program(config)

        if trainset:
            # No compiled artifact: fall back to few-shot prompting with the
            # entire training set as labeled demonstrations.
            optimizer = LabeledFewShot(k=len(trainset))
            compiled_module = optimizer.compile(student=base_module, trainset=trainset)
        else:
            compiled_module = base_module

    if config.debug:
        print("\n[GENERATING PATTERNS]")
        print(f" Match: {match_items}")
        print(f" Exclude: {exclude_items}")
        print(f" Max attempts: {config.max_attempts}")
        print(f" Stop threshold: {config.reward_threshold}")

    def reward_fn(args, pred):
        # Delegates to the shared multi-criteria scorer; Refine stops early
        # once this reward crosses config.reward_threshold.
        return compute_reward(args, pred, config)

    refine_module = dspy.Refine(
        module=compiled_module,
        N=config.max_attempts,
        reward_fn=reward_fn,
        threshold=config.reward_threshold,
        fail_count=config.fail_count,
    )

    try:
        if config.debug:
            print("\n[REFINE] Running iterative refinement...")

        result = refine_module(
            text=text,
            match_items=match_items,
            exclude_items=exclude_items,
            pattern_hints=hints_str,
        )

        # LLM output may contain escaped backslashes; normalize before scoring.
        pattern = safe_unescape(result.pattern)
        llm_result = score_pattern(
            pattern, text, match_items, exclude_items, config, source="llm"
        )
        pattern_results.append(llm_result)

        if config.debug:
            print("\n[REFINE RESULT]")
            print(f" Pattern: {pattern}")
            print(f" Score: {llm_result.total_score:.3f}")

    except Exception as e:
        # A refinement failure is non-fatal: the grex baseline (if any) still ships.
        if config.debug:
            print(f"\n[REFINE ERROR] {e}")

    # Deduplicate (by pattern+source) and sort best-first by total score.
    seen = set()
    unique_results = []
    for r in sorted(pattern_results, key=lambda x: x.total_score, reverse=True):
        key = (r.pattern, r.source)
        if key not in seen:
            seen.add(key)
            unique_results.append(r)

    return {
        "results": [r.to_dict() for r in unique_results],
        "hints_used": hints_str,
    }


# =============================================================================
# CLI Commands
# =============================================================================


def print_pattern_result(result: dict, verbose: bool = True) -> None:
    """Print a single pattern result in a formatted way."""
    pattern = result["pattern"]
    source = result["source"].upper()
    total = result["total_score"]
    scores = result["scores"]

    # "Perfect" means every required item matched AND every exclusion avoided.
    is_perfect = scores["matches_all"] == 1.0 and scores["excludes_all"] == 1.0
    status = "OK" if is_perfect else "PARTIAL"

    print(f"\n [{source}] {pattern}")
    print(f" Status: {status} | Total Score: {total:.2%}")

    if verbose:
        score_parts = [f"{k}={v:.0%}" for k, v in scores.items()]
        print(f" Scores: {', '.join(score_parts)}")

    if result["matched_items"]:
        print(f" Matched: {result['matched_items']}")
    if result["missed_items"]:
        print(f" Missed: {result['missed_items']}")
    if result["excluded_matched"]:
        print(f" Bad (matched excludes): {result['excluded_matched']}")
    if result["extra_matches"]:
        print(f" Extra: {result['extra_matches']}")


def run_tests(config: Config) -> None:
    """Run test suite with detailed output."""
    # Each case uses the same payload shape as generate_regex's input.
    test_cases = [
        # Tier 1: Basic capabilities
        {
            "text": "Items: A1, B2, C3, D4, E5",
            "Highlighted Items": ["A1", "B2", "C3"],
            "Excluded Items": [],
        },
        {
            "text": "Mixed: HELLO world GOODBYE earth",
            "Highlighted Items": ["HELLO", "GOODBYE"],
            "Excluded Items": ["world", "earth"],
        },
        {
            "text": "IDs: user_01, admin_02, guest_03",
            "Highlighted Items": ["user_01", "admin_02", "guest_03"],
            "Excluded Items": [],
        },
        # Tier 2: Word boundaries
        {
            "text": "test testing tested retest",
            "Highlighted Items": ["test"],
            "Excluded Items": ["testing", "tested", "retest"],
        },
        {
            "text": "ERR-01 ERR-02 ERR-001 ERR-002",
            "Highlighted Items": ["ERR-01", "ERR-02"],
            "Excluded Items": ["ERR-001", "ERR-002"],
        },
        {
            "text": "Call 555-1234 or 555-5678, not 555-1234-5678",
            "Highlighted Items": ["555-1234", "555-5678"],
            "Excluded Items": ["555-1234-5678"],
        },
        # Tier 3: Quantifier precision
        {
            "text": "Codes: A01 A02 A03 A001 A002",
            "Highlighted Items": ["A01", "A02", "A03"],
            "Excluded Items": ["A001", "A002"],
        },
        {
            "text": "Numbers: 12, 123, 1234, 12345",
            "Highlighted Items": ["123", "1234"],
            "Excluded Items": ["12", "12345"],
        },
        {
            "text": "Versions: v1.0, v2.1, v10.20, v1.0.0",
            "Highlighted Items": ["v1.0", "v2.1", "v10.20"],
            "Excluded Items": ["v1.0.0"],
        },
        # Tier 4: Complex exclusions
        {
            "text": "Tags: #valid #good #123 #bad# #",
            "Highlighted Items": ["#valid", "#good", "#123"],
            "Excluded Items": ["#bad#", "#"],
        },
        {
            "text": "Refs: item-001, thing-002, item_003, thing.004",
            "Highlighted Items": ["item-001", "thing-002"],
            "Excluded Items": ["item_003", "thing.004"],
        },
        {
            "text": "Keys: AB12, CD34, ab12, ABCD12",
            "Highlighted Items": ["AB12", "CD34"],
            "Excluded Items": ["ab12", "ABCD12"],
        },
        # Tier 5: Edge cases
        {
            "text": "Contacts: a@b.co, x@y.io, test@, @bad.com",
            "Highlighted Items": ["a@b.co", "x@y.io"],
            "Excluded Items": ["test@", "@bad.com"],
        },
        {
            "text": "Prices: 10.99, 20.00, 5.5, 10.999",
            "Highlighted Items": ["10.99", "20.00"],
            "Excluded Items": ["5.5", "10.999"],
        },
    ]

    # Indices into test_cases, grouped by difficulty tier for the summary.
    tiers = {
        "T1 Basic": [0, 1, 2],
        "T2 Word Boundaries": [3, 4, 5],
        "T3 Quantifiers": [6, 7, 8],
        "T4 Complex Exclusions": [9, 10, 11],
        "T5 Edge Cases": [12, 13],
    }

    print("Running test cases...\n")
    results = []

    for i, test_case in enumerate(test_cases):
        print(f"\n{'=' * 70}")
        print(f"TEST {i + 1}: {test_case['Highlighted Items']} vs {test_case['Excluded Items']}")
        print(f"{'=' * 70}")
        print(f"Text: {test_case['text'][:60]}{'...' if len(test_case['text']) > 60 else ''}")

        try:
            result = generate_regex(test_case, config)
            # A test "passes" only on a perfect LLM pattern; grex-only wins count as partial.
            llm_result = next(
                (r for r in result["results"] if r["source"] == "llm"), None
            )
            is_pass = (
                llm_result
                and llm_result["scores"]["matches_all"] == 1.0
                and llm_result["scores"]["excludes_all"] == 1.0
            )

            # Print LLM results first, then grex, regardless of score ordering.
            llm_results = [r for r in result["results"] if r["source"] == "llm"]
            grex_results = [r for r in result["results"] if r["source"] == "grex"]
            for pr in llm_results + grex_results:
                print_pattern_result(pr)

            results.append({"test": i + 1, "status": "pass" if is_pass else "partial", "result": result})
        except Exception as e:
            print(f"\n FAILED: {e}")
            results.append({"test": i + 1, "status": "failed", "error": str(e)})

    # Summary
    print(f"\n{'=' * 70}")
    print("SUMMARY")
    print(f"{'=' * 70}")

    passed = sum(1 for r in results if r["status"] == "pass")
    partial = sum(1 for r in results if r["status"] == "partial")
    failed = sum(1 for r in results if r["status"] == "failed")

    print(f"\nOverall: {passed}/{len(results)} passed, {partial} partial, {failed} failed")

    print("\nBy Tier:")
    for tier_name, indices in tiers.items():
        tier_passed = sum(
            1 for i in indices if i < len(results) and results[i]["status"] == "pass"
        )
        print(f" {tier_name}: {tier_passed}/{len(indices)}")

    failures = [r for r in results if r["status"] != "pass"]
    if failures:
        print("\nFailed/Partial Tests:")
        for r in failures:
            test_idx = r["test"] - 1
            tc = test_cases[test_idx]
            print(f" Test {r['test']}: {tc['Highlighted Items']} - {r['status']}")


def load_config_from_file(config_file: str, base_config: Config) -> Config:
    """Load config overrides from a JSON file."""
    try:
        with open(config_file, "r", encoding="utf-8") as f:
            overrides = json.load(f)

        # NOTE(review): only these five fields are carried over; any other
        # non-default field on base_config is reset to Config defaults here.
        # Confirm that is intended before adding new Config fields.
        return Config(
            model=overrides.get("model", base_config.model),
            temperature=overrides.get("temperature", base_config.temperature),
            max_attempts=overrides.get("max_attempts", base_config.max_attempts),
            reward_threshold=overrides.get("reward_threshold", base_config.reward_threshold),
            weights=overrides.get("weights", base_config.weights),
        )
    except Exception as e:
        # Any failure (missing file, bad JSON) falls back to the base config.
        print(f"[WARN] Failed to load config: {e}")
        return base_config


def cmd_list_dataset(output_file: str, config: Config) -> None:
    """Export the training dataset to a JSON file."""
    try:
        with open(config.dataset_file, "r", encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        data = []

    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2)

    print(f"Exported {len(data)} examples to {output_file}")


def cmd_add_example(example_file: str, config: Config) -> None:
    """Add a new example to the training dataset."""
    with open(example_file, "r", encoding="utf-8") as f:
        new_example = json.load(f)

    # Missing dataset file means this is the first example ever added.
    try:
        with open(config.dataset_file, "r", encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        data = []

    data.append(new_example)

    with open(config.dataset_file, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2)

    print(f"Added example. Dataset now has {len(data)} examples.")


def cmd_delete_example(index: int, config: Config) -> None:
    """Delete an example from the training dataset by index."""
    with open(config.dataset_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    if index < 0 or index >= len(data):
        print(f"[ERROR] Index {index} out of range (0-{len(data) - 1})")
        sys.exit(1)

    data.pop(index)

    with open(config.dataset_file, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2)

    print(f"Deleted example at index {index}. Dataset now has {len(data)} examples.")


def cmd_generate(input_file: str, output_file: str, config: Config) -> None:
    """Generate regex from input file."""
    with open(input_file, "r", encoding="utf-8") as f:
        input_data = json.load(f)

    if config.debug:
        print(f"Input: {json.dumps(input_data, indent=2)}")

    # Errors are written to the output file (not raised) so the AHK frontend
    # always receives a JSON response it can display.
    try:
        result = generate_regex(input_data, config)
    except Exception as e:
        result = {"error": str(e)}
        if config.debug:
            print(f"\n[ERROR] {e}")

    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(result, f, indent=2)

    if config.debug:
        print(f"\nOutput written to: {output_file}")


# =============================================================================
# Main Entry Point
# =============================================================================


def create_parser() -> argparse.ArgumentParser:
    """Create argument parser."""
    parser = argparse.ArgumentParser(
        description="Regex Generator - Generate regex patterns from examples using DSPy + Ollama",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    subparsers = parser.add_subparsers(dest="command", help="Available commands")

    # Test command
    subparsers.add_parser("test", help="Run test cases")

    # Compile command
    subparsers.add_parser("compile", help="Pre-compile for faster runtime")

    # List dataset command
    list_parser = subparsers.add_parser("list-dataset", help="Export training dataset")
    list_parser.add_argument("output", help="Output JSON file")

    # Add example command
    add_parser = subparsers.add_parser("add-example", help="Add example to dataset")
    add_parser.add_argument("example", help="Example JSON file")

    # Delete example command
    del_parser = subparsers.add_parser("delete-example", help="Delete example from dataset")
    del_parser.add_argument("index", type=int, help="Index to delete")

    # Generate command (positional args for backward compatibility)
    gen_parser = subparsers.add_parser("generate", help="Generate regex from input file")
    gen_parser.add_argument("input", help="Input JSON file")
    gen_parser.add_argument("output", help="Output JSON file")
    gen_parser.add_argument("--config", help="Config JSON file")

    return parser


def main() -> None:
    """CLI entry point."""
    config = Config()
    args = sys.argv[1:]

    # Handle legacy positional argument format: [--config ]
    # A bare first arg that is neither a flag nor a known subcommand is treated
    # as an input file path (pre-subcommand invocation style).
    if args and not args[0].startswith("-") and not args[0] in [
        "test", "compile", "list-dataset", "add-example", "delete-example", "generate"
    ]:
        input_file = args[0]
        output_file = args[1] if len(args) > 1 else None

        if not output_file:
            # NOTE(review): usage string appears to have lost its
            # "<input> <output>" placeholders during extraction — confirm.
            print("Usage: python -m regspy [--config ]")
            sys.exit(1)

        if len(args) == 4 and args[2] == "--config":
            config = load_config_from_file(args[3], config)

        cmd_generate(input_file, output_file, config)
        return

    # Handle flag-style legacy format (pre-subcommand --test/--compile/... flags).
    if args and args[0] == "--test":
        run_tests(config)
        return
    if args and args[0] == "--compile":
        print("Pre-compiling regex generator...\nThis may take several minutes.\n")
        path = compile_and_save(config)
        if path:
            print("\nDone! Use the compiled program by running normally.")
        return
    if len(args) >= 2 and args[0] == "--list-dataset":
        cmd_list_dataset(args[1], config)
        return
    if len(args) >= 2 and args[0] == "--add-example":
        cmd_add_example(args[1], config)
        return
    if len(args) >= 2 and args[0] == "--delete-example":
        try:
            index = int(args[1])
        except ValueError:
            print(f"[ERROR] Index must be an integer, got: {args[1]}")
            sys.exit(1)
        cmd_delete_example(index, config)
        return

    # Use argparse for subcommand format
    parser = create_parser()
    parsed = parser.parse_args(args)

    if parsed.command == "test":
        run_tests(config)
    elif parsed.command == "compile":
        print("Pre-compiling regex generator...\nThis may take several minutes.\n")
        path = compile_and_save(config)
        if path:
            print("\nDone! Use the compiled program by running normally.")
    elif parsed.command == "list-dataset":
        cmd_list_dataset(parsed.output, config)
    elif parsed.command == "add-example":
        cmd_add_example(parsed.example, config)
    elif parsed.command == "delete-example":
        cmd_delete_example(parsed.index, config)
    elif parsed.command == "generate":
        if parsed.config:
            config = load_config_from_file(parsed.config, config)
        cmd_generate(parsed.input, parsed.output, config)
    else:
        parser.print_help()
        print("\nLegacy usage:")
        # NOTE(review): same apparent placeholder loss as the usage string above.
        print(" python -m regspy [--config ]")


if __name__ == "__main__":
    main()