Files
ontology_graph/lib/core_functions.py
fn-registry agent 40bea81603 chore: initial sync
2026-04-28 22:13:08 +02:00

815 lines
26 KiB
Python

"""Core functional programming utilities — pure functions for list/collection operations."""
import hashlib
import re
from functools import reduce as _reduce
from typing import Any, Callable, Dict, List, Optional, Tuple
def filter_list(xs: list, pred: Callable) -> list:
    """Return a new list of the items of *xs* for which *pred* is truthy; *xs* is untouched."""
    kept = []
    for item in xs:
        if pred(item):
            kept.append(item)
    return kept
def map_list(xs: list, fn: Callable) -> list:
    """Return a new list holding ``fn(x)`` for every x in *xs*; *xs* is untouched."""
    return list(map(fn, xs))
def reduce_list(xs: list, initial: Any, fn: Callable) -> Any:
    """Fold *xs* left-to-right starting from *initial*; each step computes ``fn(acc, x)``."""
    acc = initial
    for item in xs:
        acc = fn(acc, item)
    return acc
def flat_map(xs: list, fn: Callable) -> list:
    """Apply *fn* to each element and concatenate the resulting iterables (one level)."""
    return [produced for item in xs for produced in fn(item)]
def flatten(xss: list) -> list:
    """Concatenate the inner iterables of *xss* into one list (one level deep)."""
    return [element for inner in xss for element in inner]
def chunk(xs: list, size: int) -> list:
    """Slice *xs* into consecutive pieces of *size* items; the last piece may be shorter.

    A non-positive *size* yields an empty list.
    """
    if size <= 0:
        return []
    starts = range(0, len(xs), size)
    return [xs[begin:begin + size] for begin in starts]
def take(xs: list, n: int) -> list:
    """Return the first *n* elements of *xs* (all of them if *n* exceeds the length).

    A negative *n* yields an empty slice. Plain slicing would instead drop
    elements from the end (``xs[:-1]``), which is not "the first n".
    """
    return xs[:max(n, 0)]
def drop(xs: list, n: int) -> list:
    """Return *xs* without its first *n* elements.

    A negative *n* drops nothing. Plain slicing would instead keep only the
    last ``-n`` elements (``xs[-1:]``).
    """
    return xs[max(n, 0):]
def unique(xs: list) -> list:
    """Remove duplicates (by equality), keeping the first occurrence of each value.

    Hashable elements are tracked in a set for O(1) lookups. Unhashable elements
    (dicts, lists, ...) previously raised ``TypeError``; they now fall back to a
    linear equality scan over the unhashable items kept so far. Hashable and
    unhashable values are never compared with each other.
    """
    seen_hashable = set()
    seen_unhashable = []
    result = []
    for x in xs:
        try:
            if x in seen_hashable:
                continue
            seen_hashable.add(x)
        except TypeError:
            # Unhashable: fall back to an O(k) equality scan.
            if x in seen_unhashable:
                continue
            seen_unhashable.append(x)
        result.append(x)
    return result
def group_by(xs: list, key_fn: Callable) -> Dict:
    """Bucket the elements of *xs* by ``key_fn(x)``.

    Returns a plain dict mapping each key to the list of its elements, in input
    order. Keys appear in first-seen order.
    """
    buckets: Dict = {}
    for item in xs:
        buckets.setdefault(key_fn(item), []).append(item)
    return buckets
def partition(xs: list, pred: Callable) -> Tuple[list, list]:
    """Split *xs* into ``(matches, non_matches)``, calling *pred* once per element."""
    hits, misses = [], []
    for item in xs:
        (hits if pred(item) else misses).append(item)
    return hits, misses
def find(xs: list, pred: Callable) -> Any:
    """Return the first element satisfying *pred*, or None when nothing matches."""
    return next((item for item in xs if pred(item)), None)
def find_index(xs: list, pred: Callable) -> int:
    """Return the index of the first element satisfying *pred*, or -1 if there is none."""
    return next((position for position, item in enumerate(xs) if pred(item)), -1)
def zip_with(xs: list, ys: list, fn: Callable) -> list:
    """Combine *xs* and *ys* pairwise with ``fn(x, y)``, stopping at the shorter input."""
    return list(map(fn, xs, ys))
def all_of(xs: list, pred: Callable) -> bool:
    """True when every element satisfies *pred* (vacuously True for empty input)."""
    return all(map(pred, xs))
def any_of(xs: list, pred: Callable) -> bool:
    """True when at least one element satisfies *pred* (False for empty input)."""
    return any(map(pred, xs))
def pipe(value: Any, *fns: Callable) -> Any:
    """Thread *value* through *fns* left-to-right: ``pipe(x, f, g) == g(f(x))``."""
    return _reduce(lambda acc, step: step(acc), fns, value)
def compose(*fns: Callable) -> Callable:
    """Build the right-to-left composition of *fns*: ``compose(f, g)(x) == f(g(x))``.

    With no functions the result is the identity.
    """
    ordered = fns[::-1]

    def composed(x: Any) -> Any:
        for step in ordered:
            x = step(x)
        return x

    return composed
# ── Tree manipulation ────────────────────────────────────────────────────────
def flatten_tree(structure: Any) -> List[Dict]:
    """Flatten a hierarchical tree into a pre-order list of child-free node copies.

    Each dict becomes a deep copy of itself without its 'nodes' key. Recursion
    follows every key whose name contains 'nodes'. Lists are flattened item by
    item; any other value contributes nothing. The input is not modified.
    """
    import copy
    if isinstance(structure, dict):
        # Deep-copy only this node's own fields. Copying the whole dict and then
        # popping 'nodes' would also deep-copy every descendant at every level.
        # A shared memo keeps aliasing between this node's fields intact.
        memo: Dict[int, Any] = {}
        node = {k: copy.deepcopy(v, memo) for k, v in structure.items() if k != 'nodes'}
        nodes = [node]
        for key in list(structure.keys()):
            if 'nodes' in key:
                nodes.extend(flatten_tree(structure[key]))
        return nodes
    if isinstance(structure, list):
        nodes = []
        for item in structure:
            nodes.extend(flatten_tree(item))
        return nodes
    return []
def tree_to_flat_list(structure: Any) -> List[Dict]:
    """List every dict node of a tree in depth-first pre-order (internal nodes included).

    The returned entries are the original dict objects (no copies). Recursion
    follows only the exact 'nodes' key. Non-dict, non-list values are ignored.
    """
    if isinstance(structure, list):
        collected: List[Dict] = []
        for item in structure:
            collected += tree_to_flat_list(item)
        return collected
    if not isinstance(structure, dict):
        return []
    collected = [structure]
    if 'nodes' in structure:
        collected += tree_to_flat_list(structure['nodes'])
    return collected
def get_leaf_nodes(structure: Any) -> List[Dict]:
    """Collect deep copies of the leaf nodes of a hierarchical tree.

    A dict is a leaf when its 'nodes' entry is missing or empty; its copy has any
    'nodes' key removed. For a non-leaf dict, recursion follows every key whose
    name contains 'nodes'. Lists are processed item by item; other values yield
    nothing. The input is not modified.
    """
    import copy
    if isinstance(structure, list):
        return [leaf for item in structure for leaf in get_leaf_nodes(item)]
    if not isinstance(structure, dict):
        return []
    if not structure.get('nodes'):
        leaf = copy.deepcopy(structure)
        leaf.pop('nodes', None)
        return [leaf]
    child_keys = [key for key in structure if 'nodes' in key]
    return [leaf for key in child_keys for leaf in get_leaf_nodes(structure[key])]
def write_node_ids(data: Any, node_id: int = 0) -> int:
    """Assign sequential zero-padded IDs to every dict node, in depth-first pre-order.

    Numbering starts at *node_id*, so the default first ID is '0000'. Values are
    ``str(n).zfill(4)``, padded to at least four digits. Each dict gets its
    'node_id' set in place. Children are found under every key whose name
    contains 'nodes'. Lists are walked item by item.

    Returns:
        The next unused counter value.
    """
    counter = node_id
    if isinstance(data, list):
        for item in data:
            counter = write_node_ids(item, counter)
    elif isinstance(data, dict):
        data['node_id'] = str(counter).zfill(4)
        counter += 1
        for key in [k for k in data if 'nodes' in k]:
            counter = write_node_ids(data[key], counter)
    return counter
def list_to_tree(data: List[Dict]) -> List[Dict]:
    """Nest a flat list of items carrying dotted 'structure' codes ('1', '1.2', ...).

    Each item becomes ``{'title', 'start_index', 'end_index', 'nodes'}``. It is
    attached to the node whose code is its own code minus the last '.' segment,
    provided that parent appeared earlier in *data*. Otherwise it becomes a root.
    Empty 'nodes' lists are removed from the result.
    """
    def _parent_code(code):
        if not code:
            return None
        head, dot, _ = str(code).rpartition('.')
        return head if dot else None

    by_code = {}
    roots = []
    for item in data:
        code = item.get('structure')
        node = {
            'title': item.get('title'),
            'start_index': item.get('start_index'),
            'end_index': item.get('end_index'),
            'nodes': [],
        }
        by_code[code] = node
        parent = _parent_code(code)
        siblings = by_code[parent]['nodes'] if parent and parent in by_code else roots
        siblings.append(node)

    def _prune(node):
        if node['nodes']:
            for child in node['nodes']:
                _prune(child)
        else:
            del node['nodes']
        return node

    return [_prune(root) for root in roots]
def remove_tree_fields(data: Any, fields: Optional[List[str]] = None) -> Any:
    """Return a copy of a dict/list tree with the given keys removed at every level.

    Args:
        data: Arbitrary nesting of dicts and lists. Other values are returned as-is.
        fields: Keys to drop. Defaults to ``['text']``.

    Returns:
        A new structure of rebuilt dicts and lists. *data* itself is not modified.
    """
    if fields is None:
        fields = ['text']
    if isinstance(data, dict):
        return {k: remove_tree_fields(v, fields) for k, v in data.items() if k not in fields}
    if isinstance(data, list):
        return [remove_tree_fields(item, fields) for item in data]
    return data
def format_tree_structure(structure: Any, order: Optional[List[str]] = None) -> Any:
    """Rebuild every node of a tree with its keys in *order* (other keys dropped).

    Children under 'nodes' are formatted recursively. An empty or missing
    'nodes' is omitted. The caller's tree is not mutated; the original
    implementation replaced and popped 'nodes' on the input dicts in place.

    Args:
        structure: Dict node, list of nodes, or any other value (returned as-is).
        order: Key order for output nodes. If falsy, *structure* is returned
            unchanged (same object).
    """
    if not order:
        return structure
    if isinstance(structure, dict):
        # Shallow copy so the caller's dict keeps its original 'nodes'.
        node = dict(structure)
        if 'nodes' in node:
            node['nodes'] = format_tree_structure(node['nodes'], order)
        if not node.get('nodes'):
            node.pop('nodes', None)
        return {key: node[key] for key in order if key in node}
    if isinstance(structure, list):
        return [format_tree_structure(item, order) for item in structure]
    return structure
def create_node_mapping(tree: List[Dict]) -> Dict[str, Dict]:
    """Index every node of a tree by its 'node_id' for O(1) lookup.

    Nodes are visited in depth-first pre-order through their 'nodes' lists.
    Nodes whose 'node_id' is missing or falsy are skipped. On duplicate IDs the
    later node wins. Values are the original node dicts, not copies.
    """
    def _walk(nodes):
        for node in nodes:
            yield node
            children = node.get('nodes')
            if children:
                yield from _walk(children)

    return {node['node_id']: node for node in _walk(tree) if node.get('node_id')}
# ── Text / JSON extraction ───────────────────────────────────────────────────
def extract_json_from_llm(content: str) -> Dict:
    """Best-effort extraction and parsing of JSON from an LLM response.

    Steps:
      1. If a fenced ```json block is present, take the text after it, up to the
         last ``` (or to the end when the closing fence is missing).
      2. Replace bare Python ``None`` tokens outside JSON strings with ``null``.
      3. Collapse every whitespace run (including newlines) to a single space.
      4. Parse. If that fails, drop trailing commas before ``]``/``}`` (outside
         strings, whitespace allowed) and parse again.

    Returns:
        Whatever ``json.loads`` produces (usually a dict), or ``{}`` if the input
        is not a string or cannot be parsed.
    """
    import json

    # A JSON string literal, so substitutions can skip text inside strings.
    string_literal = r'"(?:\\.|[^"\\])*"'

    def _sub_outside_strings(pattern: str, repl: str, text: str) -> str:
        token_re = re.compile(string_literal + '|' + pattern, re.DOTALL)

        def _swap(m):
            token = m.group(0)
            return token if token.startswith('"') else m.expand(repl)

        return token_re.sub(_swap, text)

    if not isinstance(content, str):
        return {}

    start_idx = content.find("```json")
    if start_idx != -1:
        start_idx += len("```json")
        end_idx = content.rfind("```")
        if end_idx < start_idx:
            # No closing fence: rfind hit the opening one, so take the rest.
            end_idx = len(content)
        json_content = content[start_idx:end_idx].strip()
    else:
        json_content = content.strip()

    json_content = _sub_outside_strings(r'\bNone\b', 'null', json_content)
    json_content = ' '.join(json_content.split())
    try:
        return json.loads(json_content)
    except (ValueError, RecursionError):
        pass

    json_content = _sub_outside_strings(r',\s*([\]}])', r'\1', json_content)
    try:
        return json.loads(json_content)
    except (ValueError, RecursionError):
        return {}
def parse_page_range(pages: str) -> List[int]:
    """Expand a page spec such as '5-7', '3,8' or '12' into a sorted list of unique ints.

    Comma-separated parts may be single numbers or inclusive 'start-end' ranges.

    Raises:
        ValueError: On a non-integer part, or on a range whose start exceeds its end.
    """
    found = set()
    for raw in pages.split(','):
        token = raw.strip()
        if '-' not in token:
            found.add(int(token))
            continue
        low_text, _, high_text = token.partition('-')
        low, high = int(low_text.strip()), int(high_text.strip())
        if low > high:
            raise ValueError(f"Invalid range '{token}': start must be <= end")
        found.update(range(low, high + 1))
    return sorted(found)
# ── Markdown parsing ─────────────────────────────────────────────────────────
def extract_markdown_headers(markdown_content: str) -> Tuple[List[Dict], List[str]]:
    """Find h1-h6 ATX headers in markdown, ignoring lines inside ``` fences.

    Each line is stripped before matching. A stripped line starting with ```
    toggles the fence state and is never itself a header.

    Returns:
        ``(headers, lines)``. *headers* is a list of
        ``{'title', 'level', 'line_num'}`` dicts with 1-based line numbers.
        *lines* is the content split on '\\n'.
    """
    heading_re = re.compile(r'^(#{1,6})\s+(.+)$')
    fence_re = re.compile(r'^```')
    lines = markdown_content.split('\n')
    headers: List[Dict] = []
    inside_fence = False
    for number, raw_line in enumerate(lines, 1):
        text = raw_line.strip()
        if fence_re.match(text):
            inside_fence = not inside_fence
            continue
        if inside_fence or not text:
            continue
        found = heading_re.match(text)
        if found:
            headers.append({
                'title': found.group(2).strip(),
                'level': len(found.group(1)),
                'line_num': number,
            })
    return headers, lines
def build_tree_from_headers(node_list: List[Dict]) -> List[Dict]:
    """Nest a flat list of ``{'title', 'level', 'line_num'}`` headers into a tree.

    A header becomes a child of the closest preceding header with a strictly
    smaller level. Nodes get sequential IDs '0001', '0002', ... in input order.
    Leaf nodes carry no 'nodes' key.
    """
    if not node_list:
        return []
    roots: List[Dict] = []
    ancestors: List[Tuple[Dict, int]] = []  # (node, level) along the current path
    for index, header in enumerate(node_list, start=1):
        level = header['level']
        tree_node = {
            'title': header['title'],
            'node_id': str(index).zfill(4),
            'line_num': header['line_num'],
            'nodes': [],
        }
        while ancestors and ancestors[-1][1] >= level:
            ancestors.pop()
        (ancestors[-1][0]['nodes'] if ancestors else roots).append(tree_node)
        ancestors.append((tree_node, level))

    def _drop_empty_children(nodes):
        for entry in nodes:
            if not entry['nodes']:
                del entry['nodes']
            else:
                _drop_empty_children(entry['nodes'])
        return nodes

    return _drop_empty_children(roots)
# ── Pagination / chunking ────────────────────────────────────────────────────
def page_list_to_groups(page_contents: List[str], token_lengths: List[int],
                        max_tokens: int = 20000, overlap_pages: int = 1) -> List[str]:
    """Group pages into text chunks against a token budget, with page overlap.

    If the total of *token_lengths* fits in *max_tokens*, one chunk holding all
    pages is returned. Otherwise pages are packed greedily against a target of
    ``ceil((total / ceil(total / max_tokens) + max_tokens) / 2)`` tokens. Each
    new chunk begins with the last *overlap_pages* pages of the previous one.
    Values <= 0 give no overlap. A chunk is only closed once it holds at least
    one page, so an oversized first page no longer produces an empty chunk.

    Args:
        page_contents: Page texts, in order.
        token_lengths: Token count per page, aligned with *page_contents*.
        max_tokens: Budget that triggers splitting.
        overlap_pages: Trailing pages repeated at the start of the next chunk.

    Returns:
        Concatenated chunk strings. Chunks may exceed *max_tokens* when single
        pages or the overlap are large.
    """
    import math
    num_tokens = sum(token_lengths)
    if num_tokens <= max_tokens:
        return ["".join(page_contents)]
    subsets = []
    current_subset = []
    current_token_count = 0
    expected_parts = math.ceil(num_tokens / max_tokens)
    avg_tokens = math.ceil(((num_tokens / expected_parts) + max_tokens) / 2)
    for i, (page_content, page_tokens) in enumerate(zip(page_contents, token_lengths)):
        # Never close an empty chunk (previously emitted '' when page 0 was oversized).
        if current_subset and current_token_count + page_tokens > avg_tokens:
            subsets.append(''.join(current_subset))
            overlap_start = max(i - overlap_pages, 0)
            current_subset = list(page_contents[overlap_start:i])
            current_token_count = sum(token_lengths[overlap_start:i])
        current_subset.append(page_content)
        current_token_count += page_tokens
    if current_subset:
        subsets.append(''.join(current_subset))
    return subsets
def calculate_page_offset(pairs: List[Dict]) -> int:
    """Return the most common ``physical_index - page`` difference across *pairs*.

    Pairs that are missing either key, or whose values cannot be subtracted, are
    skipped. Ties go to the difference seen first. Returns 0 when no pair is
    usable.
    """
    tally: Dict[int, int] = {}
    for pair in pairs:
        try:
            offset = pair['physical_index'] - pair['page']
        except (KeyError, TypeError):
            continue
        tally[offset] = tally.get(offset, 0) + 1
    if not tally:
        return 0
    best, best_count = None, 0
    for offset, count in tally.items():
        if count > best_count:
            best, best_count = offset, count
    return best
# ── Text preprocessing ───────────────────────────────────────────────────────
def preprocess_text(text: str) -> str:
    """Normalize whitespace and newlines in raw text.

    Args:
        text: Raw text to normalize.

    Returns:
        Text with '\\n' line endings, each line stripped, at most one blank line
        between paragraphs, and no leading or trailing whitespace.
    """
    # Normalize line endings: \r\n and \r -> \n
    text = text.replace('\r\n', '\n').replace('\r', '\n')
    # Strip each line first, so whitespace-only lines become empty and are
    # caught by the collapse below. Collapsing first left runs such as
    # "a\n \n \nb" -> "a\n\n\nb".
    text = '\n'.join(line.strip() for line in text.split('\n'))
    # Reduce 3+ consecutive newlines to exactly 2 (one blank line).
    text = re.sub(r'\n{3,}', '\n\n', text)
    return text.strip()
def get_text_stats(text: str) -> dict:
    """Count characters, lines and whitespace-separated words in *text*.

    Lines are counted as newlines + 1, so empty text reports one line.

    Returns:
        ``{'total_chars': int, 'total_lines': int, 'total_words': int}``.
    """
    newline_count = text.count('\n')
    word_count = len(text.split())
    return {
        'total_chars': len(text),
        'total_lines': newline_count + 1,
        'total_words': word_count,
    }
# ── Git URL parsing ──────────────────────────────────────────────────────────
_DEFAULT_GIT_HOSTS = ["github.com", "gitlab.com"]
def _sanitize_git_segment(segment: str) -> str:
"""Strip .git suffix then keep only [a-zA-Z0-9_-] chars."""
if segment.endswith(".git"):
segment = segment[:-4]
return re.sub(r"[^a-zA-Z0-9_\-]", "", segment)
def parse_git_url(url: str, known_hosts: Optional[List[str]] = None) -> Optional[str]:
    """Parse a code-hosting URL and return its 'org/repo' path component.

    Supports HTTPS, HTTP, git://, ssh:// and the SSH shorthand git@host:path.
    The first two non-empty path segments are passed through
    ``_sanitize_git_segment``.

    Args:
        url: Repository URL in any supported format (surrounding whitespace ignored).
        known_hosts: Accepted hostnames. Defaults to ``_DEFAULT_GIT_HOSTS``.

    Returns:
        'org/repo', or None when the scheme or host is not accepted, fewer than
        two path segments exist, or a segment sanitizes to an empty string.
    """
    from urllib.parse import urlparse
    hosts = _DEFAULT_GIT_HOSTS if known_hosts is None else known_hosts
    url = url.strip()

    def _org_repo(path: str) -> Optional[str]:
        parts = [piece for piece in path.split("/") if piece]
        if len(parts) < 2:
            return None
        org, repo = _sanitize_git_segment(parts[0]), _sanitize_git_segment(parts[1])
        return f"{org}/{repo}" if org and repo else None

    if url.startswith("git@"):
        # Shorthand: git@github.com:org/repo.git
        host, colon, path = url[len("git@"):].partition(":")
        if not colon or host not in hosts:
            return None
        return _org_repo(path)

    if url.startswith(("http://", "https://", "git://", "ssh://")):
        parsed = urlparse(url)
        if (parsed.hostname or "") not in hosts:
            return None
        return _org_repo(parsed.path)

    return None
def is_git_repo_url(url: str, known_hosts: Optional[List[str]] = None) -> bool:
    """Return True only if *url* points at a clonable git repository.

    SSH shorthand (git@host:...), ssh:// and git:// URLs are accepted whenever
    the host is known. For http(s) the path must be exactly org/repo or
    org/repo/tree/<ref>, where <ref> is a single path segment. Deeper paths
    (issues, blobs, PRs, ...) are rejected.

    Args:
        url: URL to check (surrounding whitespace ignored).
        known_hosts: Accepted hostnames. Defaults to ``_DEFAULT_GIT_HOSTS``.
    """
    from urllib.parse import urlparse
    hosts = _DEFAULT_GIT_HOSTS if known_hosts is None else known_hosts
    url = url.strip()

    if url.startswith("git@"):
        host, colon, _ = url[len("git@"):].partition(":")
        return bool(colon) and host in hosts

    if url.startswith(("ssh://", "git://")):
        return (urlparse(url).hostname or "") in hosts

    if not url.startswith(("http://", "https://")):
        return False
    parsed = urlparse(url)
    if (parsed.hostname or "") not in hosts:
        return False
    segments = [s for s in parsed.path.split("/") if s]
    return len(segments) == 2 or (len(segments) == 4 and segments[2] == "tree")
def validate_git_ssh_uri(url: str) -> None:
    """Check that *url* has the git SSH shape ``git@host:path``.

    Only the prefix, the ':' separator and a non-empty path are verified; the
    host is not checked.

    Raises:
        ValueError: Describing the first violated rule.
    """
    if not url.startswith("git@"):
        raise ValueError(f"git SSH URI must start with 'git@', got: {url!r}")
    _, colon, path = url[len("git@"):].partition(":")
    if not colon:
        raise ValueError(f"git SSH URI must contain ':', got: {url!r}")
    if not path:
        raise ValueError(f"git SSH URI must have a non-empty path after ':', got: {url!r}")
# ---------------------------------------------------------------------------
# Markdown parsing utilities
# ---------------------------------------------------------------------------
def extract_frontmatter(content: str) -> Tuple[str, Optional[Dict]]:
    """Extract YAML frontmatter delimited by '---' lines from the start of markdown.

    Accepts LF or CRLF line endings. The closing '---' may end the file without
    a trailing newline, and the frontmatter body may be empty. The original
    pattern required '\\n---\\n' and so missed all three cases.

    Args:
        content: Raw markdown string, optionally starting with frontmatter.

    Returns:
        ``(content_without_frontmatter, frontmatter_dict)``. The dict is None
        when no frontmatter block is found. It is also None when PyYAML parses
        the block to a non-dict. Without PyYAML, or if parsing raises, a naive
        ``key: value`` parser is used and every value is a string.
    """
    pattern = re.compile(r'^---\r?\n(?:(.*?)\r?\n)?---(?:\r?\n|\Z)', re.DOTALL)
    match = pattern.match(content)
    if not match:
        return content, None
    raw = match.group(1) or ''
    remaining = content[match.end():]
    try:
        import yaml  # type: ignore
        data = yaml.safe_load(raw)
        if not isinstance(data, dict):
            data = None
    except Exception:
        # Deliberate best-effort fallback: PyYAML missing or the YAML is invalid.
        data = {}
        for line in raw.splitlines():
            if ':' in line:
                key, _, value = line.partition(':')
                data[key.strip()] = value.strip()
    return remaining, data
def find_headings(content: str) -> List[Tuple[int, int, str, int]]:
    """Find ATX markdown headings (# to ######) outside code blocks and HTML comments.

    A heading must start at column 0 with 1-6 '#' followed by spaces or tabs on
    the same line. The original ``\\s+`` could span a newline, so a lone "#"
    line swallowed the next line as its title.

    Args:
        content: Markdown text to search.

    Returns:
        ``(start_pos, end_pos, title, level)`` for each heading found, in order.
    """
    excluded: List[Tuple[int, int]] = []
    # Fenced code blocks (triple backtick)
    for m in re.finditer(r'```.*?```', content, re.DOTALL):
        excluded.append((m.start(), m.end()))
    # HTML comments
    for m in re.finditer(r'<!--.*?-->', content, re.DOTALL):
        excluded.append((m.start(), m.end()))
    # Indented code lines (4 spaces or a tab). Headings are anchored at column 0,
    # so these ranges only matter if the heading pattern is ever relaxed.
    for m in re.finditer(r'^(?: {4}|\t).+$', content, re.MULTILINE):
        excluded.append((m.start(), m.end()))

    def is_excluded(pos: int) -> bool:
        return any(start <= pos < end for start, end in excluded)

    results: List[Tuple[int, int, str, int]] = []
    # An escaped heading ("\# x") never matches, because '#' must be the first
    # character of the line. No separate backslash check is needed.
    for m in re.finditer(r'^(#{1,6})[ \t]+(.+)$', content, re.MULTILINE):
        if is_excluded(m.start()):
            continue
        level = len(m.group(1))
        title = m.group(2).strip()
        results.append((m.start(), m.end(), title, level))
    return results
def estimate_token_count(content: str) -> int:
    """Roughly estimate the token count of *content* without a tokenizer.

    Characters in U+4E00-U+9FFF, U+3040-U+30FF and U+AC00-U+D7AF count as 0.7
    tokens each. Every other non-whitespace character counts as 0.3.
    Whitespace counts as 0. The float total is truncated with ``int``.
    """
    cjk_count = len(re.findall(r'[\u4e00-\u9fff\u3040-\u30ff\uac00-\ud7af]', content))
    # No character in those ranges is whitespace, so every CJK char is also
    # counted by \S. Subtract them to get the remaining non-whitespace chars.
    other_count = len(re.findall(r'\S', content)) - cjk_count
    return int(cjk_count * 0.7 + other_count * 0.3)
def smart_split_content(
    content: str,
    max_tokens: int = 1024,
    max_chars: int = 8000,
) -> List[str]:
    """Split large content into parts respecting token and character limits.

    Paragraphs (separated by '\\n\\n') are packed greedily and re-joined with
    '\\n\\n'. The joining separators count toward *max_chars*. A paragraph that
    exceeds either limit on its own is force-cut: first into slices of at most
    *max_chars* characters. Any slice still above *max_tokens* (per
    ``estimate_token_count``) is then halved repeatedly until it fits or is one
    character long. Previously the slices ignored the token limit; with the
    defaults, an 8000-char ASCII slice can estimate at about 2400 tokens.

    Args:
        content: Text to split.
        max_tokens: Maximum estimated tokens per part.
        max_chars: Maximum characters per part (must be positive).

    Returns:
        List of string parts, or ``[content]`` if no part was produced.
    """
    separator = '\n\n'
    paragraphs = content.split(separator)
    parts: List[str] = []
    current_parts: List[str] = []
    current_tokens = 0
    current_chars = 0

    def flush() -> None:
        if current_parts:
            parts.append(separator.join(current_parts))
            current_parts.clear()

    def force_cut(text: str) -> List[str]:
        # Slice by characters, then halve any slice that is still over the token limit.
        pieces: List[str] = []
        for begin in range(0, len(text), max_chars):
            pending = [text[begin:begin + max_chars]]  # stack; left half popped first
            while pending:
                piece = pending.pop()
                if len(piece) > 1 and estimate_token_count(piece) > max_tokens:
                    mid = len(piece) // 2
                    pending.extend((piece[mid:], piece[:mid]))
                else:
                    pieces.append(piece)
        return pieces

    for para in paragraphs:
        para_tokens = estimate_token_count(para)
        para_chars = len(para)
        # Single paragraph exceeds limits: force-cut it into its own parts.
        if para_tokens > max_tokens or para_chars > max_chars:
            flush()
            current_tokens = 0
            current_chars = 0
            parts.extend(force_cut(para))
            continue
        # Joining inserts a separator before every paragraph except the first.
        sep_chars = len(separator) if current_parts else 0
        if (current_tokens + para_tokens > max_tokens or
                current_chars + sep_chars + para_chars > max_chars):
            flush()
            current_tokens = 0
            current_chars = 0
            sep_chars = 0
        current_parts.append(para)
        current_tokens += para_tokens
        current_chars += sep_chars + para_chars
    flush()
    return parts if parts else [content]
def sanitize_for_path(text: str, max_length: int = 50) -> str:
    """Convert text to a string that is safe to use in file paths.

    Keeps word characters, CJK/kana/Hangul characters, spaces and hyphens.
    Spaces become underscores, and leading/trailing underscores are stripped.
    If the result exceeds *max_length*, it is truncated and suffixed with '_'
    plus the first 8 hex chars of the SHA-256 of *text*, so distinct long inputs
    stay distinct.

    Args:
        text: Input text to sanitize.
        max_length: Maximum length of the returned string. If it is too small
            to fit any text plus the 9-char suffix, the bare digest (truncated
            to ``max(max_length, 1)`` chars) is returned. Previously a negative
            slice made such results far longer than *max_length*.

    Returns:
        Safe path-friendly string. Returns 'section' when nothing survives
        cleaning.
    """
    cleaned = re.sub(
        r'[^\w\u4e00-\u9fff\u3040-\u30ff\uac00-\ud7af \-]',
        '',
        text,
    )
    cleaned = cleaned.replace(' ', '_').strip('_')
    if not cleaned:
        return 'section'
    if len(cleaned) <= max_length:
        return cleaned
    digest = hashlib.sha256(text.encode()).hexdigest()[:8]
    suffix = '_' + digest
    if max_length < len(suffix):
        # No room for any text plus the suffix; fall back to the bare digest.
        return digest[:max(max_length, 1)]
    return cleaned[:max_length - len(suffix)] + suffix