"""Core functional programming utilities — pure functions for list/collection operations.""" import hashlib import re from functools import reduce as _reduce from typing import Any, Callable, Dict, List, Optional, Tuple def filter_list(xs: list, pred: Callable) -> list: """Filter list by predicate. Does not mutate the original.""" return [x for x in xs if pred(x)] def map_list(xs: list, fn: Callable) -> list: """Map function over list. Does not mutate the original.""" return [fn(x) for x in xs] def reduce_list(xs: list, initial: Any, fn: Callable) -> Any: """Reduce list with accumulator. fn(acc, x) -> acc.""" return _reduce(fn, xs, initial) def flat_map(xs: list, fn: Callable) -> list: """Map function over list then flatten one level.""" result = [] for x in xs: result.extend(fn(x)) return result def flatten(xss: list) -> list: """Flatten a list of lists one level.""" result = [] for xs in xss: result.extend(xs) return result def chunk(xs: list, size: int) -> list: """Split list into chunks of given size. Last chunk may be smaller.""" if size <= 0: return [] return [xs[i : i + size] for i in range(0, len(xs), size)] def take(xs: list, n: int) -> list: """Take first n elements from list.""" return xs[:n] def drop(xs: list, n: int) -> list: """Drop first n elements from list.""" return xs[n:] def unique(xs: list) -> list: """Remove duplicates preserving order. Uses identity for hashable elements.""" seen = set() result = [] for x in xs: if x not in seen: seen.add(x) result.append(x) return result def group_by(xs: list, key_fn: Callable) -> Dict: """Group elements by key function. Returns dict of key -> list.""" groups: Dict = {} for x in xs: k = key_fn(x) if k not in groups: groups[k] = [] groups[k].append(x) return groups def partition(xs: list, pred: Callable) -> Tuple[list, list]: """Split list into (matches, non_matches) based on predicate.""" matches = [] non_matches = [] for x in xs: if pred(x): matches.append(x) else: non_matches.append(x) return (matches, non_matches) def find(xs: list, pred: Callable) -> Any: """Find first element matching predicate. Returns None if not found.""" for x in xs: if pred(x): return x return None def find_index(xs: list, pred: Callable) -> int: """Find index of first element matching predicate. Returns -1 if not found.""" for i, x in enumerate(xs): if pred(x): return i return -1 def zip_with(xs: list, ys: list, fn: Callable) -> list: """Zip two lists with a combining function. Stops at shorter list.""" return [fn(x, y) for x, y in zip(xs, ys)] def all_of(xs: list, pred: Callable) -> bool: """Return True if all elements match predicate.""" return all(pred(x) for x in xs) def any_of(xs: list, pred: Callable) -> bool: """Return True if any element matches predicate.""" return any(pred(x) for x in xs) def pipe(value: Any, *fns: Callable) -> Any: """Pipe a value through a sequence of functions left-to-right.""" result = value for fn in fns: result = fn(result) return result def compose(*fns: Callable) -> Callable: """Compose functions right-to-left. compose(f, g)(x) == f(g(x)).""" def composed(x: Any) -> Any: result = x for fn in reversed(fns): result = fn(result) return result return composed # ── Tree manipulation ──────────────────────────────────────────────────────── def flatten_tree(structure: Any) -> List[Dict]: """Flatten a hierarchical tree (dict with 'nodes') to a list without children.""" import copy if isinstance(structure, dict): node = copy.deepcopy(structure) node.pop('nodes', None) nodes = [node] for key in list(structure.keys()): if 'nodes' in key: nodes.extend(flatten_tree(structure[key])) return nodes elif isinstance(structure, list): nodes = [] for item in structure: nodes.extend(flatten_tree(item)) return nodes return [] def tree_to_flat_list(structure: Any) -> List[Dict]: """Convert hierarchical tree to flat list preserving DFS order (keeps internal nodes).""" if isinstance(structure, dict): nodes = [structure] if 'nodes' in structure: nodes.extend(tree_to_flat_list(structure['nodes'])) return nodes elif isinstance(structure, list): nodes = [] for item in structure: nodes.extend(tree_to_flat_list(item)) return nodes return [] def get_leaf_nodes(structure: Any) -> List[Dict]: """Extract only leaf nodes (no children) from a hierarchical tree.""" import copy if isinstance(structure, dict): if not structure.get('nodes'): node = copy.deepcopy(structure) node.pop('nodes', None) return [node] leaf_nodes = [] for key in list(structure.keys()): if 'nodes' in key: leaf_nodes.extend(get_leaf_nodes(structure[key])) return leaf_nodes elif isinstance(structure, list): leaf_nodes = [] for item in structure: leaf_nodes.extend(get_leaf_nodes(item)) return leaf_nodes return [] def write_node_ids(data: Any, node_id: int = 0) -> int: """Assign sequential zero-padded IDs (0001, 0002...) to all nodes in a tree. Returns next counter.""" if isinstance(data, dict): data['node_id'] = str(node_id).zfill(4) node_id += 1 for key in list(data.keys()): if 'nodes' in key: node_id = write_node_ids(data[key], node_id) elif isinstance(data, list): for item in data: node_id = write_node_ids(item, node_id) return node_id def list_to_tree(data: List[Dict]) -> List[Dict]: """Convert flat list with structure codes ('1.2.3') to nested tree.""" def get_parent_structure(structure): if not structure: return None parts = str(structure).split('.') return '.'.join(parts[:-1]) if len(parts) > 1 else None nodes = {} root_nodes = [] for item in data: structure = item.get('structure') node = { 'title': item.get('title'), 'start_index': item.get('start_index'), 'end_index': item.get('end_index'), 'nodes': [] } nodes[structure] = node parent_structure = get_parent_structure(structure) if parent_structure and parent_structure in nodes: nodes[parent_structure]['nodes'].append(node) else: root_nodes.append(node) def clean_node(node): if not node['nodes']: del node['nodes'] else: for child in node['nodes']: clean_node(child) return node return [clean_node(node) for node in root_nodes] def remove_tree_fields(data: Any, fields: List[str] = None) -> Any: """Recursively remove specified fields from a tree (dict/list).""" if fields is None: fields = ['text'] if isinstance(data, dict): return {k: remove_tree_fields(v, fields) for k, v in data.items() if k not in fields} elif isinstance(data, list): return [remove_tree_fields(item, fields) for item in data] return data def format_tree_structure(structure: Any, order: List[str] = None) -> Any: """Reorder fields of each node in a tree according to specified key order.""" if not order: return structure if isinstance(structure, dict): if 'nodes' in structure: structure['nodes'] = format_tree_structure(structure['nodes'], order) if not structure.get('nodes'): structure.pop('nodes', None) return {key: structure[key] for key in order if key in structure} elif isinstance(structure, list): return [format_tree_structure(item, order) for item in structure] return structure def create_node_mapping(tree: List[Dict]) -> Dict[str, Dict]: """Create flat dict mapping node_id to node for O(1) lookup.""" mapping = {} def _traverse(nodes): for node in nodes: if node.get('node_id'): mapping[node['node_id']] = node if node.get('nodes'): _traverse(node['nodes']) _traverse(tree) return mapping # ── Text / JSON extraction ─────────────────────────────────────────────────── def extract_json_from_llm(content: str) -> Dict: """Extract and parse JSON from LLM responses. Handles ```json blocks, trailing commas, None->null.""" import json try: start_idx = content.find("```json") if start_idx != -1: start_idx += 7 end_idx = content.rfind("```") json_content = content[start_idx:end_idx].strip() else: json_content = content.strip() json_content = json_content.replace('None', 'null') json_content = json_content.replace('\n', ' ').replace('\r', ' ') json_content = ' '.join(json_content.split()) return json.loads(json_content) except (json.JSONDecodeError, Exception): try: json_content = json_content.replace(',]', ']').replace(',}', '}') return json.loads(json_content) except Exception: return {} def parse_page_range(pages: str) -> List[int]: """Parse page range string ('5-7', '3,8', '12') into sorted list of unique ints.""" result = [] for part in pages.split(','): part = part.strip() if '-' in part: start, end = int(part.split('-', 1)[0].strip()), int(part.split('-', 1)[1].strip()) if start > end: raise ValueError(f"Invalid range '{part}': start must be <= end") result.extend(range(start, end + 1)) else: result.append(int(part)) return sorted(set(result)) # ── Markdown parsing ───────────────────────────────────────────────────────── def extract_markdown_headers(markdown_content: str) -> Tuple[List[Dict], List[str]]: """Extract all headers (h1-h6) from markdown with line numbers, skipping code blocks.""" import re header_pattern = r'^(#{1,6})\s+(.+)$' code_block_pattern = r'^```' node_list = [] lines = markdown_content.split('\n') in_code_block = False for line_num, line in enumerate(lines, 1): stripped_line = line.strip() if re.match(code_block_pattern, stripped_line): in_code_block = not in_code_block continue if not stripped_line: continue if not in_code_block: match = re.match(header_pattern, stripped_line) if match: level = len(match.group(1)) title = match.group(2).strip() node_list.append({'title': title, 'level': level, 'line_num': line_num}) return node_list, lines def build_tree_from_headers(node_list: List[Dict]) -> List[Dict]: """Build nested tree from flat list of headers with levels (h1>h2>h3).""" if not node_list: return [] stack = [] root_nodes = [] node_counter = 1 for node in node_list: current_level = node['level'] tree_node = { 'title': node['title'], 'node_id': str(node_counter).zfill(4), 'line_num': node['line_num'], 'nodes': [] } node_counter += 1 while stack and stack[-1][1] >= current_level: stack.pop() if not stack: root_nodes.append(tree_node) else: parent_node, _ = stack[-1] parent_node['nodes'].append(tree_node) stack.append((tree_node, current_level)) def clean_empty_nodes(nodes): for n in nodes: if n['nodes']: clean_empty_nodes(n['nodes']) else: del n['nodes'] return nodes return clean_empty_nodes(root_nodes) # ── Pagination / chunking ──────────────────────────────────────────────────── def page_list_to_groups(page_contents: List[str], token_lengths: List[int], max_tokens: int = 20000, overlap_pages: int = 1) -> List[str]: """Group pages into text chunks respecting token limit with configurable overlap.""" import math num_tokens = sum(token_lengths) if num_tokens <= max_tokens: return ["".join(page_contents)] subsets = [] current_subset = [] current_token_count = 0 expected_parts = math.ceil(num_tokens / max_tokens) avg_tokens = math.ceil(((num_tokens / expected_parts) + max_tokens) / 2) for i, (page_content, page_tokens) in enumerate(zip(page_contents, token_lengths)): if current_token_count + page_tokens > avg_tokens: subsets.append(''.join(current_subset)) overlap_start = max(i - overlap_pages, 0) current_subset = list(page_contents[overlap_start:i]) current_token_count = sum(token_lengths[overlap_start:i]) current_subset.append(page_content) current_token_count += page_tokens if current_subset: subsets.append(''.join(current_subset)) return subsets def calculate_page_offset(pairs: List[Dict]) -> int: """Calculate offset between logical page numbers and physical indices using reference pairs.""" differences = [] for pair in pairs: try: difference = pair['physical_index'] - pair['page'] differences.append(difference) except (KeyError, TypeError): continue if not differences: return 0 counts: Dict[int, int] = {} for diff in differences: counts[diff] = counts.get(diff, 0) + 1 return max(counts.items(), key=lambda x: x[1])[0] # ── Text preprocessing ─────────────────────────────────────────────────────── def preprocess_text(text: str) -> str: """Normalize whitespace and newlines in raw text. Args: text: Raw text to normalize. Returns: Normalized text with consistent newlines, stripped lines, and no excessive blank lines. """ # Normalize line endings: \r\n and \r -> \n text = text.replace('\r\n', '\n').replace('\r', '\n') # Reduce 3+ consecutive newlines to at most 2 text = re.sub(r'\n{3,}', '\n\n', text) # Strip whitespace from each line text = '\n'.join(line.strip() for line in text.split('\n')) # Strip globally return text.strip() def get_text_stats(text: str) -> dict: """Compute basic statistics of a text: characters, lines, words. Args: text: Input text to analyze. Returns: Dict with keys total_chars (int), total_lines (int), total_words (int). """ return { 'total_chars': len(text), 'total_lines': text.count('\n') + 1, 'total_words': len(text.split()), } # ── Git URL parsing ────────────────────────────────────────────────────────── _DEFAULT_GIT_HOSTS = ["github.com", "gitlab.com"] def _sanitize_git_segment(segment: str) -> str: """Strip .git suffix then keep only [a-zA-Z0-9_-] chars.""" if segment.endswith(".git"): segment = segment[:-4] return re.sub(r"[^a-zA-Z0-9_\-]", "", segment) def parse_git_url(url: str, known_hosts: Optional[List[str]] = None) -> Optional[str]: """Parse a code-hosting URL and return the 'org/repo' path component. Supports HTTPS, HTTP, git://, ssh:// and SSH shorthand (git@host:path). Returns None if the URL does not match any known host or is malformed. Args: url: Repository URL in any supported format. known_hosts: List of accepted hostnames. Defaults to github.com and gitlab.com. Returns: 'org/repo' string or None. """ from urllib.parse import urlparse hosts = known_hosts if known_hosts is not None else _DEFAULT_GIT_HOSTS url = url.strip() if url.startswith("git@"): # git@github.com:org/repo.git rest = url[len("git@"):] if ":" not in rest: return None host, path = rest.split(":", 1) if host not in hosts: return None segments = [s for s in path.split("/") if s] if len(segments) < 2: return None org = _sanitize_git_segment(segments[0]) repo = _sanitize_git_segment(segments[1]) if not org or not repo: return None return f"{org}/{repo}" for prefix in ("http://", "https://", "git://", "ssh://"): if url.startswith(prefix): parsed = urlparse(url) netloc = parsed.hostname or "" if netloc not in hosts: return None segments = [s for s in parsed.path.split("/") if s] if len(segments) < 2: return None org = _sanitize_git_segment(segments[0]) repo = _sanitize_git_segment(segments[1]) if not org or not repo: return None return f"{org}/{repo}" return None def is_git_repo_url(url: str, known_hosts: Optional[List[str]] = None) -> bool: """Return True only if url points to a clonable git repository. Accepts org/repo and org/repo/tree/ paths. Rejects paths that navigate to sub-resources (issues, blobs, PRs, etc.). Args: url: URL to verify. known_hosts: Accepted hostnames. Defaults to github.com and gitlab.com. Returns: True if url is a clonable repository URL. """ from urllib.parse import urlparse hosts = known_hosts if known_hosts is not None else _DEFAULT_GIT_HOSTS url = url.strip() # SSH shorthand — always repo-level if host matches if url.startswith("git@"): rest = url[len("git@"):] if ":" not in rest: return False host, _ = rest.split(":", 1) return host in hosts # git:// and ssh:// — always repo-level if host matches for prefix in ("ssh://", "git://"): if url.startswith(prefix): parsed = urlparse(url) return (parsed.hostname or "") in hosts # http:// and https:// — must have exactly org/repo or org/repo/tree/ for prefix in ("http://", "https://"): if url.startswith(prefix): parsed = urlparse(url) if (parsed.hostname or "") not in hosts: return False segments = [s for s in parsed.path.split("/") if s] if len(segments) == 2: return True if len(segments) == 4 and segments[2] == "tree": return True return False return False def validate_git_ssh_uri(url: str) -> None: """Validate a git SSH URI of the form git@host:path. Raises ValueError with a descriptive message if the URI is malformed. Args: url: URI string to validate. Raises: ValueError: If the URI does not conform to git SSH format. """ if not url.startswith("git@"): raise ValueError(f"git SSH URI must start with 'git@', got: {url!r}") rest = url[len("git@"):] if ":" not in rest: raise ValueError(f"git SSH URI must contain ':', got: {url!r}") _, path = rest.split(":", 1) if not path: raise ValueError(f"git SSH URI must have a non-empty path after ':', got: {url!r}") # --------------------------------------------------------------------------- # Markdown parsing utilities # --------------------------------------------------------------------------- def extract_frontmatter(content: str) -> Tuple[str, Optional[Dict]]: """Extract YAML frontmatter delimited by '---' from the start of a markdown string. Args: content: Raw markdown string, optionally starting with YAML frontmatter. Returns: Tuple of (content_without_frontmatter, frontmatter_dict). frontmatter_dict is None when no frontmatter is found. """ pattern = re.compile(r'^---\n(.*?)\n---\n', re.DOTALL) match = pattern.match(content) if not match: return content, None raw = match.group(1) remaining = content[match.end():] try: import yaml # type: ignore data = yaml.safe_load(raw) if not isinstance(data, dict): data = None except Exception: # Fallback: simple key: value parser (no yaml dependency) data = {} for line in raw.splitlines(): if ':' in line: key, _, value = line.partition(':') data[key.strip()] = value.strip() return remaining, data def find_headings(content: str) -> List[Tuple[int, int, str, int]]: """Find all markdown headings (# to ######), excluding those inside code blocks, HTML comments, and indented blocks. Args: content: Markdown text to search. Returns: List of (start_pos, end_pos, title, level) for each heading found. """ excluded: List[Tuple[int, int]] = [] # Code blocks (triple backtick) for m in re.finditer(r'```.*?```', content, re.DOTALL): excluded.append((m.start(), m.end())) # HTML comments for m in re.finditer(r'', content, re.DOTALL): excluded.append((m.start(), m.end())) # Indented blocks (lines starting with 4 spaces or a tab) for m in re.finditer(r'^( |\t).+$', content, re.MULTILINE): excluded.append((m.start(), m.end())) def is_excluded(pos: int) -> bool: return any(start <= pos < end for start, end in excluded) results: List[Tuple[int, int, str, int]] = [] for m in re.finditer(r'^(#{1,6})\s+(.+)$', content, re.MULTILINE): # Skip escaped headings (\#) before = content[m.start() - 1] if m.start() > 0 else '' if before == '\\': continue if is_excluded(m.start()): continue level = len(m.group(1)) title = m.group(2).strip() results.append((m.start(), m.end(), title, level)) return results def estimate_token_count(content: str) -> int: """Estimate token count without a tokenizer. CJK characters count as ~0.7 tokens each; other non-whitespace characters count as ~0.3 tokens each. Args: content: Text to estimate. Returns: Estimated integer token count. """ cjk = re.findall(r'[\u4e00-\u9fff\u3040-\u30ff\uac00-\ud7af]', content) without_cjk = re.sub(r'[\u4e00-\u9fff\u3040-\u30ff\uac00-\ud7af]', '', content) others = re.findall(r'\S', without_cjk) return int(len(cjk) * 0.7 + len(others) * 0.3) def smart_split_content( content: str, max_tokens: int = 1024, max_chars: int = 8000, ) -> List[str]: """Split large content into parts respecting token and character limits. Splits by paragraphs (double newline). If a single paragraph exceeds the limit it is force-cut into chunks of max_chars. Args: content: Text to split. max_tokens: Maximum estimated tokens per part. max_chars: Maximum characters per part. Returns: List of string parts. """ paragraphs = content.split('\n\n') parts: List[str] = [] current_parts: List[str] = [] current_tokens = 0 current_chars = 0 def flush() -> None: if current_parts: parts.append('\n\n'.join(current_parts)) current_parts.clear() for para in paragraphs: para_tokens = estimate_token_count(para) para_chars = len(para) # Single paragraph exceeds limits — force-cut it if para_tokens > max_tokens or para_chars > max_chars: flush() current_tokens = 0 current_chars = 0 for i in range(0, len(para), max_chars): parts.append(para[i:i + max_chars]) continue # Would exceed limits if added — flush first if (current_tokens + para_tokens > max_tokens or current_chars + para_chars > max_chars): flush() current_tokens = 0 current_chars = 0 current_parts.append(para) current_tokens += para_tokens current_chars += para_chars flush() return parts if parts else [content] def sanitize_for_path(text: str, max_length: int = 50) -> str: """Convert text to a safe string for use in file paths. Keeps word characters, CJK characters, spaces and hyphens. Replaces spaces with underscores. Truncates with a sha256 suffix if the result exceeds max_length. Args: text: Input text to sanitize. max_length: Maximum length of the returned string. Returns: Safe path-friendly string. """ cleaned = re.sub( r'[^\w\u4e00-\u9fff\u3040-\u30ff\uac00-\ud7af \-]', '', text, ) cleaned = cleaned.replace(' ', '_').strip('_') if not cleaned: return 'section' if len(cleaned) <= max_length: return cleaned suffix = '_' + hashlib.sha256(text.encode()).hexdigest()[:8] return cleaned[:max_length - len(suffix)] + suffix