chore: initial sync
This commit is contained in:
@@ -0,0 +1,814 @@
|
||||
"""Core functional programming utilities — pure functions for list/collection operations."""
|
||||
|
||||
import hashlib
|
||||
import re
|
||||
from functools import reduce as _reduce
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
|
||||
def filter_list(xs: list, pred: Callable) -> list:
|
||||
"""Filter list by predicate. Does not mutate the original."""
|
||||
return [x for x in xs if pred(x)]
|
||||
|
||||
|
||||
def map_list(xs: list, fn: Callable) -> list:
|
||||
"""Map function over list. Does not mutate the original."""
|
||||
return [fn(x) for x in xs]
|
||||
|
||||
|
||||
def reduce_list(xs: list, initial: Any, fn: Callable) -> Any:
|
||||
"""Reduce list with accumulator. fn(acc, x) -> acc."""
|
||||
return _reduce(fn, xs, initial)
|
||||
|
||||
|
||||
def flat_map(xs: list, fn: Callable) -> list:
|
||||
"""Map function over list then flatten one level."""
|
||||
result = []
|
||||
for x in xs:
|
||||
result.extend(fn(x))
|
||||
return result
|
||||
|
||||
|
||||
def flatten(xss: list) -> list:
|
||||
"""Flatten a list of lists one level."""
|
||||
result = []
|
||||
for xs in xss:
|
||||
result.extend(xs)
|
||||
return result
|
||||
|
||||
|
||||
def chunk(xs: list, size: int) -> list:
|
||||
"""Split list into chunks of given size. Last chunk may be smaller."""
|
||||
if size <= 0:
|
||||
return []
|
||||
return [xs[i : i + size] for i in range(0, len(xs), size)]
|
||||
|
||||
|
||||
def take(xs: list, n: int) -> list:
|
||||
"""Take first n elements from list."""
|
||||
return xs[:n]
|
||||
|
||||
|
||||
def drop(xs: list, n: int) -> list:
|
||||
"""Drop first n elements from list."""
|
||||
return xs[n:]
|
||||
|
||||
|
||||
def unique(xs: list) -> list:
|
||||
"""Remove duplicates preserving order. Uses identity for hashable elements."""
|
||||
seen = set()
|
||||
result = []
|
||||
for x in xs:
|
||||
if x not in seen:
|
||||
seen.add(x)
|
||||
result.append(x)
|
||||
return result
|
||||
|
||||
|
||||
def group_by(xs: list, key_fn: Callable) -> Dict:
|
||||
"""Group elements by key function. Returns dict of key -> list."""
|
||||
groups: Dict = {}
|
||||
for x in xs:
|
||||
k = key_fn(x)
|
||||
if k not in groups:
|
||||
groups[k] = []
|
||||
groups[k].append(x)
|
||||
return groups
|
||||
|
||||
|
||||
def partition(xs: list, pred: Callable) -> Tuple[list, list]:
|
||||
"""Split list into (matches, non_matches) based on predicate."""
|
||||
matches = []
|
||||
non_matches = []
|
||||
for x in xs:
|
||||
if pred(x):
|
||||
matches.append(x)
|
||||
else:
|
||||
non_matches.append(x)
|
||||
return (matches, non_matches)
|
||||
|
||||
|
||||
def find(xs: list, pred: Callable) -> Any:
|
||||
"""Find first element matching predicate. Returns None if not found."""
|
||||
for x in xs:
|
||||
if pred(x):
|
||||
return x
|
||||
return None
|
||||
|
||||
|
||||
def find_index(xs: list, pred: Callable) -> int:
|
||||
"""Find index of first element matching predicate. Returns -1 if not found."""
|
||||
for i, x in enumerate(xs):
|
||||
if pred(x):
|
||||
return i
|
||||
return -1
|
||||
|
||||
|
||||
def zip_with(xs: list, ys: list, fn: Callable) -> list:
|
||||
"""Zip two lists with a combining function. Stops at shorter list."""
|
||||
return [fn(x, y) for x, y in zip(xs, ys)]
|
||||
|
||||
|
||||
def all_of(xs: list, pred: Callable) -> bool:
|
||||
"""Return True if all elements match predicate."""
|
||||
return all(pred(x) for x in xs)
|
||||
|
||||
|
||||
def any_of(xs: list, pred: Callable) -> bool:
|
||||
"""Return True if any element matches predicate."""
|
||||
return any(pred(x) for x in xs)
|
||||
|
||||
|
||||
def pipe(value: Any, *fns: Callable) -> Any:
|
||||
"""Pipe a value through a sequence of functions left-to-right."""
|
||||
result = value
|
||||
for fn in fns:
|
||||
result = fn(result)
|
||||
return result
|
||||
|
||||
|
||||
def compose(*fns: Callable) -> Callable:
|
||||
"""Compose functions right-to-left. compose(f, g)(x) == f(g(x))."""
|
||||
def composed(x: Any) -> Any:
|
||||
result = x
|
||||
for fn in reversed(fns):
|
||||
result = fn(result)
|
||||
return result
|
||||
return composed
|
||||
|
||||
|
||||
# ── Tree manipulation ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def flatten_tree(structure: Any) -> List[Dict]:
|
||||
"""Flatten a hierarchical tree (dict with 'nodes') to a list without children."""
|
||||
import copy
|
||||
if isinstance(structure, dict):
|
||||
node = copy.deepcopy(structure)
|
||||
node.pop('nodes', None)
|
||||
nodes = [node]
|
||||
for key in list(structure.keys()):
|
||||
if 'nodes' in key:
|
||||
nodes.extend(flatten_tree(structure[key]))
|
||||
return nodes
|
||||
elif isinstance(structure, list):
|
||||
nodes = []
|
||||
for item in structure:
|
||||
nodes.extend(flatten_tree(item))
|
||||
return nodes
|
||||
return []
|
||||
|
||||
|
||||
def tree_to_flat_list(structure: Any) -> List[Dict]:
|
||||
"""Convert hierarchical tree to flat list preserving DFS order (keeps internal nodes)."""
|
||||
if isinstance(structure, dict):
|
||||
nodes = [structure]
|
||||
if 'nodes' in structure:
|
||||
nodes.extend(tree_to_flat_list(structure['nodes']))
|
||||
return nodes
|
||||
elif isinstance(structure, list):
|
||||
nodes = []
|
||||
for item in structure:
|
||||
nodes.extend(tree_to_flat_list(item))
|
||||
return nodes
|
||||
return []
|
||||
|
||||
|
||||
def get_leaf_nodes(structure: Any) -> List[Dict]:
|
||||
"""Extract only leaf nodes (no children) from a hierarchical tree."""
|
||||
import copy
|
||||
if isinstance(structure, dict):
|
||||
if not structure.get('nodes'):
|
||||
node = copy.deepcopy(structure)
|
||||
node.pop('nodes', None)
|
||||
return [node]
|
||||
leaf_nodes = []
|
||||
for key in list(structure.keys()):
|
||||
if 'nodes' in key:
|
||||
leaf_nodes.extend(get_leaf_nodes(structure[key]))
|
||||
return leaf_nodes
|
||||
elif isinstance(structure, list):
|
||||
leaf_nodes = []
|
||||
for item in structure:
|
||||
leaf_nodes.extend(get_leaf_nodes(item))
|
||||
return leaf_nodes
|
||||
return []
|
||||
|
||||
|
||||
def write_node_ids(data: Any, node_id: int = 0) -> int:
|
||||
"""Assign sequential zero-padded IDs (0001, 0002...) to all nodes in a tree. Returns next counter."""
|
||||
if isinstance(data, dict):
|
||||
data['node_id'] = str(node_id).zfill(4)
|
||||
node_id += 1
|
||||
for key in list(data.keys()):
|
||||
if 'nodes' in key:
|
||||
node_id = write_node_ids(data[key], node_id)
|
||||
elif isinstance(data, list):
|
||||
for item in data:
|
||||
node_id = write_node_ids(item, node_id)
|
||||
return node_id
|
||||
|
||||
|
||||
def list_to_tree(data: List[Dict]) -> List[Dict]:
|
||||
"""Convert flat list with structure codes ('1.2.3') to nested tree."""
|
||||
def get_parent_structure(structure):
|
||||
if not structure:
|
||||
return None
|
||||
parts = str(structure).split('.')
|
||||
return '.'.join(parts[:-1]) if len(parts) > 1 else None
|
||||
|
||||
nodes = {}
|
||||
root_nodes = []
|
||||
|
||||
for item in data:
|
||||
structure = item.get('structure')
|
||||
node = {
|
||||
'title': item.get('title'),
|
||||
'start_index': item.get('start_index'),
|
||||
'end_index': item.get('end_index'),
|
||||
'nodes': []
|
||||
}
|
||||
nodes[structure] = node
|
||||
parent_structure = get_parent_structure(structure)
|
||||
|
||||
if parent_structure and parent_structure in nodes:
|
||||
nodes[parent_structure]['nodes'].append(node)
|
||||
else:
|
||||
root_nodes.append(node)
|
||||
|
||||
def clean_node(node):
|
||||
if not node['nodes']:
|
||||
del node['nodes']
|
||||
else:
|
||||
for child in node['nodes']:
|
||||
clean_node(child)
|
||||
return node
|
||||
|
||||
return [clean_node(node) for node in root_nodes]
|
||||
|
||||
|
||||
def remove_tree_fields(data: Any, fields: List[str] = None) -> Any:
|
||||
"""Recursively remove specified fields from a tree (dict/list)."""
|
||||
if fields is None:
|
||||
fields = ['text']
|
||||
if isinstance(data, dict):
|
||||
return {k: remove_tree_fields(v, fields) for k, v in data.items() if k not in fields}
|
||||
elif isinstance(data, list):
|
||||
return [remove_tree_fields(item, fields) for item in data]
|
||||
return data
|
||||
|
||||
|
||||
def format_tree_structure(structure: Any, order: List[str] = None) -> Any:
|
||||
"""Reorder fields of each node in a tree according to specified key order."""
|
||||
if not order:
|
||||
return structure
|
||||
if isinstance(structure, dict):
|
||||
if 'nodes' in structure:
|
||||
structure['nodes'] = format_tree_structure(structure['nodes'], order)
|
||||
if not structure.get('nodes'):
|
||||
structure.pop('nodes', None)
|
||||
return {key: structure[key] for key in order if key in structure}
|
||||
elif isinstance(structure, list):
|
||||
return [format_tree_structure(item, order) for item in structure]
|
||||
return structure
|
||||
|
||||
|
||||
def create_node_mapping(tree: List[Dict]) -> Dict[str, Dict]:
|
||||
"""Create flat dict mapping node_id to node for O(1) lookup."""
|
||||
mapping = {}
|
||||
def _traverse(nodes):
|
||||
for node in nodes:
|
||||
if node.get('node_id'):
|
||||
mapping[node['node_id']] = node
|
||||
if node.get('nodes'):
|
||||
_traverse(node['nodes'])
|
||||
_traverse(tree)
|
||||
return mapping
|
||||
|
||||
|
||||
# ── Text / JSON extraction ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
def extract_json_from_llm(content: str) -> Dict:
|
||||
"""Extract and parse JSON from LLM responses. Handles ```json blocks, trailing commas, None->null."""
|
||||
import json
|
||||
try:
|
||||
start_idx = content.find("```json")
|
||||
if start_idx != -1:
|
||||
start_idx += 7
|
||||
end_idx = content.rfind("```")
|
||||
json_content = content[start_idx:end_idx].strip()
|
||||
else:
|
||||
json_content = content.strip()
|
||||
|
||||
json_content = json_content.replace('None', 'null')
|
||||
json_content = json_content.replace('\n', ' ').replace('\r', ' ')
|
||||
json_content = ' '.join(json_content.split())
|
||||
|
||||
return json.loads(json_content)
|
||||
except (json.JSONDecodeError, Exception):
|
||||
try:
|
||||
json_content = json_content.replace(',]', ']').replace(',}', '}')
|
||||
return json.loads(json_content)
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
|
||||
def parse_page_range(pages: str) -> List[int]:
|
||||
"""Parse page range string ('5-7', '3,8', '12') into sorted list of unique ints."""
|
||||
result = []
|
||||
for part in pages.split(','):
|
||||
part = part.strip()
|
||||
if '-' in part:
|
||||
start, end = int(part.split('-', 1)[0].strip()), int(part.split('-', 1)[1].strip())
|
||||
if start > end:
|
||||
raise ValueError(f"Invalid range '{part}': start must be <= end")
|
||||
result.extend(range(start, end + 1))
|
||||
else:
|
||||
result.append(int(part))
|
||||
return sorted(set(result))
|
||||
|
||||
|
||||
# ── Markdown parsing ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def extract_markdown_headers(markdown_content: str) -> Tuple[List[Dict], List[str]]:
|
||||
"""Extract all headers (h1-h6) from markdown with line numbers, skipping code blocks."""
|
||||
import re
|
||||
header_pattern = r'^(#{1,6})\s+(.+)$'
|
||||
code_block_pattern = r'^```'
|
||||
node_list = []
|
||||
lines = markdown_content.split('\n')
|
||||
in_code_block = False
|
||||
|
||||
for line_num, line in enumerate(lines, 1):
|
||||
stripped_line = line.strip()
|
||||
if re.match(code_block_pattern, stripped_line):
|
||||
in_code_block = not in_code_block
|
||||
continue
|
||||
if not stripped_line:
|
||||
continue
|
||||
if not in_code_block:
|
||||
match = re.match(header_pattern, stripped_line)
|
||||
if match:
|
||||
level = len(match.group(1))
|
||||
title = match.group(2).strip()
|
||||
node_list.append({'title': title, 'level': level, 'line_num': line_num})
|
||||
|
||||
return node_list, lines
|
||||
|
||||
|
||||
def build_tree_from_headers(node_list: List[Dict]) -> List[Dict]:
|
||||
"""Build nested tree from flat list of headers with levels (h1>h2>h3)."""
|
||||
if not node_list:
|
||||
return []
|
||||
|
||||
stack = []
|
||||
root_nodes = []
|
||||
node_counter = 1
|
||||
|
||||
for node in node_list:
|
||||
current_level = node['level']
|
||||
tree_node = {
|
||||
'title': node['title'],
|
||||
'node_id': str(node_counter).zfill(4),
|
||||
'line_num': node['line_num'],
|
||||
'nodes': []
|
||||
}
|
||||
node_counter += 1
|
||||
|
||||
while stack and stack[-1][1] >= current_level:
|
||||
stack.pop()
|
||||
|
||||
if not stack:
|
||||
root_nodes.append(tree_node)
|
||||
else:
|
||||
parent_node, _ = stack[-1]
|
||||
parent_node['nodes'].append(tree_node)
|
||||
|
||||
stack.append((tree_node, current_level))
|
||||
|
||||
def clean_empty_nodes(nodes):
|
||||
for n in nodes:
|
||||
if n['nodes']:
|
||||
clean_empty_nodes(n['nodes'])
|
||||
else:
|
||||
del n['nodes']
|
||||
return nodes
|
||||
|
||||
return clean_empty_nodes(root_nodes)
|
||||
|
||||
|
||||
# ── Pagination / chunking ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def page_list_to_groups(page_contents: List[str], token_lengths: List[int],
|
||||
max_tokens: int = 20000, overlap_pages: int = 1) -> List[str]:
|
||||
"""Group pages into text chunks respecting token limit with configurable overlap."""
|
||||
import math
|
||||
num_tokens = sum(token_lengths)
|
||||
|
||||
if num_tokens <= max_tokens:
|
||||
return ["".join(page_contents)]
|
||||
|
||||
subsets = []
|
||||
current_subset = []
|
||||
current_token_count = 0
|
||||
|
||||
expected_parts = math.ceil(num_tokens / max_tokens)
|
||||
avg_tokens = math.ceil(((num_tokens / expected_parts) + max_tokens) / 2)
|
||||
|
||||
for i, (page_content, page_tokens) in enumerate(zip(page_contents, token_lengths)):
|
||||
if current_token_count + page_tokens > avg_tokens:
|
||||
subsets.append(''.join(current_subset))
|
||||
overlap_start = max(i - overlap_pages, 0)
|
||||
current_subset = list(page_contents[overlap_start:i])
|
||||
current_token_count = sum(token_lengths[overlap_start:i])
|
||||
|
||||
current_subset.append(page_content)
|
||||
current_token_count += page_tokens
|
||||
|
||||
if current_subset:
|
||||
subsets.append(''.join(current_subset))
|
||||
|
||||
return subsets
|
||||
|
||||
|
||||
def calculate_page_offset(pairs: List[Dict]) -> int:
|
||||
"""Calculate offset between logical page numbers and physical indices using reference pairs."""
|
||||
differences = []
|
||||
for pair in pairs:
|
||||
try:
|
||||
difference = pair['physical_index'] - pair['page']
|
||||
differences.append(difference)
|
||||
except (KeyError, TypeError):
|
||||
continue
|
||||
|
||||
if not differences:
|
||||
return 0
|
||||
|
||||
counts: Dict[int, int] = {}
|
||||
for diff in differences:
|
||||
counts[diff] = counts.get(diff, 0) + 1
|
||||
|
||||
return max(counts.items(), key=lambda x: x[1])[0]
|
||||
|
||||
|
||||
# ── Text preprocessing ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def preprocess_text(text: str) -> str:
|
||||
"""Normalize whitespace and newlines in raw text.
|
||||
|
||||
Args:
|
||||
text: Raw text to normalize.
|
||||
|
||||
Returns:
|
||||
Normalized text with consistent newlines, stripped lines, and no
|
||||
excessive blank lines.
|
||||
"""
|
||||
# Normalize line endings: \r\n and \r -> \n
|
||||
text = text.replace('\r\n', '\n').replace('\r', '\n')
|
||||
# Reduce 3+ consecutive newlines to at most 2
|
||||
text = re.sub(r'\n{3,}', '\n\n', text)
|
||||
# Strip whitespace from each line
|
||||
text = '\n'.join(line.strip() for line in text.split('\n'))
|
||||
# Strip globally
|
||||
return text.strip()
|
||||
|
||||
|
||||
def get_text_stats(text: str) -> dict:
|
||||
"""Compute basic statistics of a text: characters, lines, words.
|
||||
|
||||
Args:
|
||||
text: Input text to analyze.
|
||||
|
||||
Returns:
|
||||
Dict with keys total_chars (int), total_lines (int), total_words (int).
|
||||
"""
|
||||
return {
|
||||
'total_chars': len(text),
|
||||
'total_lines': text.count('\n') + 1,
|
||||
'total_words': len(text.split()),
|
||||
}
|
||||
|
||||
|
||||
# ── Git URL parsing ──────────────────────────────────────────────────────────
|
||||
|
||||
_DEFAULT_GIT_HOSTS = ["github.com", "gitlab.com"]
|
||||
|
||||
|
||||
def _sanitize_git_segment(segment: str) -> str:
|
||||
"""Strip .git suffix then keep only [a-zA-Z0-9_-] chars."""
|
||||
if segment.endswith(".git"):
|
||||
segment = segment[:-4]
|
||||
return re.sub(r"[^a-zA-Z0-9_\-]", "", segment)
|
||||
|
||||
|
||||
def parse_git_url(url: str, known_hosts: Optional[List[str]] = None) -> Optional[str]:
|
||||
"""Parse a code-hosting URL and return the 'org/repo' path component.
|
||||
|
||||
Supports HTTPS, HTTP, git://, ssh:// and SSH shorthand (git@host:path).
|
||||
Returns None if the URL does not match any known host or is malformed.
|
||||
|
||||
Args:
|
||||
url: Repository URL in any supported format.
|
||||
known_hosts: List of accepted hostnames. Defaults to github.com and gitlab.com.
|
||||
|
||||
Returns:
|
||||
'org/repo' string or None.
|
||||
"""
|
||||
from urllib.parse import urlparse
|
||||
|
||||
hosts = known_hosts if known_hosts is not None else _DEFAULT_GIT_HOSTS
|
||||
url = url.strip()
|
||||
|
||||
if url.startswith("git@"):
|
||||
# git@github.com:org/repo.git
|
||||
rest = url[len("git@"):]
|
||||
if ":" not in rest:
|
||||
return None
|
||||
host, path = rest.split(":", 1)
|
||||
if host not in hosts:
|
||||
return None
|
||||
segments = [s for s in path.split("/") if s]
|
||||
if len(segments) < 2:
|
||||
return None
|
||||
org = _sanitize_git_segment(segments[0])
|
||||
repo = _sanitize_git_segment(segments[1])
|
||||
if not org or not repo:
|
||||
return None
|
||||
return f"{org}/{repo}"
|
||||
|
||||
for prefix in ("http://", "https://", "git://", "ssh://"):
|
||||
if url.startswith(prefix):
|
||||
parsed = urlparse(url)
|
||||
netloc = parsed.hostname or ""
|
||||
if netloc not in hosts:
|
||||
return None
|
||||
segments = [s for s in parsed.path.split("/") if s]
|
||||
if len(segments) < 2:
|
||||
return None
|
||||
org = _sanitize_git_segment(segments[0])
|
||||
repo = _sanitize_git_segment(segments[1])
|
||||
if not org or not repo:
|
||||
return None
|
||||
return f"{org}/{repo}"
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def is_git_repo_url(url: str, known_hosts: Optional[List[str]] = None) -> bool:
|
||||
"""Return True only if url points to a clonable git repository.
|
||||
|
||||
Accepts org/repo and org/repo/tree/<ref> paths.
|
||||
Rejects paths that navigate to sub-resources (issues, blobs, PRs, etc.).
|
||||
|
||||
Args:
|
||||
url: URL to verify.
|
||||
known_hosts: Accepted hostnames. Defaults to github.com and gitlab.com.
|
||||
|
||||
Returns:
|
||||
True if url is a clonable repository URL.
|
||||
"""
|
||||
from urllib.parse import urlparse
|
||||
|
||||
hosts = known_hosts if known_hosts is not None else _DEFAULT_GIT_HOSTS
|
||||
url = url.strip()
|
||||
|
||||
# SSH shorthand — always repo-level if host matches
|
||||
if url.startswith("git@"):
|
||||
rest = url[len("git@"):]
|
||||
if ":" not in rest:
|
||||
return False
|
||||
host, _ = rest.split(":", 1)
|
||||
return host in hosts
|
||||
|
||||
# git:// and ssh:// — always repo-level if host matches
|
||||
for prefix in ("ssh://", "git://"):
|
||||
if url.startswith(prefix):
|
||||
parsed = urlparse(url)
|
||||
return (parsed.hostname or "") in hosts
|
||||
|
||||
# http:// and https:// — must have exactly org/repo or org/repo/tree/<ref>
|
||||
for prefix in ("http://", "https://"):
|
||||
if url.startswith(prefix):
|
||||
parsed = urlparse(url)
|
||||
if (parsed.hostname or "") not in hosts:
|
||||
return False
|
||||
segments = [s for s in parsed.path.split("/") if s]
|
||||
if len(segments) == 2:
|
||||
return True
|
||||
if len(segments) == 4 and segments[2] == "tree":
|
||||
return True
|
||||
return False
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def validate_git_ssh_uri(url: str) -> None:
|
||||
"""Validate a git SSH URI of the form git@host:path.
|
||||
|
||||
Raises ValueError with a descriptive message if the URI is malformed.
|
||||
|
||||
Args:
|
||||
url: URI string to validate.
|
||||
|
||||
Raises:
|
||||
ValueError: If the URI does not conform to git SSH format.
|
||||
"""
|
||||
if not url.startswith("git@"):
|
||||
raise ValueError(f"git SSH URI must start with 'git@', got: {url!r}")
|
||||
rest = url[len("git@"):]
|
||||
if ":" not in rest:
|
||||
raise ValueError(f"git SSH URI must contain ':', got: {url!r}")
|
||||
_, path = rest.split(":", 1)
|
||||
if not path:
|
||||
raise ValueError(f"git SSH URI must have a non-empty path after ':', got: {url!r}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Markdown parsing utilities
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def extract_frontmatter(content: str) -> Tuple[str, Optional[Dict]]:
|
||||
"""Extract YAML frontmatter delimited by '---' from the start of a markdown string.
|
||||
|
||||
Args:
|
||||
content: Raw markdown string, optionally starting with YAML frontmatter.
|
||||
|
||||
Returns:
|
||||
Tuple of (content_without_frontmatter, frontmatter_dict).
|
||||
frontmatter_dict is None when no frontmatter is found.
|
||||
"""
|
||||
pattern = re.compile(r'^---\n(.*?)\n---\n', re.DOTALL)
|
||||
match = pattern.match(content)
|
||||
if not match:
|
||||
return content, None
|
||||
|
||||
raw = match.group(1)
|
||||
remaining = content[match.end():]
|
||||
|
||||
try:
|
||||
import yaml # type: ignore
|
||||
data = yaml.safe_load(raw)
|
||||
if not isinstance(data, dict):
|
||||
data = None
|
||||
except Exception:
|
||||
# Fallback: simple key: value parser (no yaml dependency)
|
||||
data = {}
|
||||
for line in raw.splitlines():
|
||||
if ':' in line:
|
||||
key, _, value = line.partition(':')
|
||||
data[key.strip()] = value.strip()
|
||||
|
||||
return remaining, data
|
||||
|
||||
|
||||
def find_headings(content: str) -> List[Tuple[int, int, str, int]]:
|
||||
"""Find all markdown headings (# to ######), excluding those inside code blocks,
|
||||
HTML comments, and indented blocks.
|
||||
|
||||
Args:
|
||||
content: Markdown text to search.
|
||||
|
||||
Returns:
|
||||
List of (start_pos, end_pos, title, level) for each heading found.
|
||||
"""
|
||||
excluded: List[Tuple[int, int]] = []
|
||||
|
||||
# Code blocks (triple backtick)
|
||||
for m in re.finditer(r'```.*?```', content, re.DOTALL):
|
||||
excluded.append((m.start(), m.end()))
|
||||
|
||||
# HTML comments
|
||||
for m in re.finditer(r'<!--.*?-->', content, re.DOTALL):
|
||||
excluded.append((m.start(), m.end()))
|
||||
|
||||
# Indented blocks (lines starting with 4 spaces or a tab)
|
||||
for m in re.finditer(r'^( |\t).+$', content, re.MULTILINE):
|
||||
excluded.append((m.start(), m.end()))
|
||||
|
||||
def is_excluded(pos: int) -> bool:
|
||||
return any(start <= pos < end for start, end in excluded)
|
||||
|
||||
results: List[Tuple[int, int, str, int]] = []
|
||||
for m in re.finditer(r'^(#{1,6})\s+(.+)$', content, re.MULTILINE):
|
||||
# Skip escaped headings (\#)
|
||||
before = content[m.start() - 1] if m.start() > 0 else ''
|
||||
if before == '\\':
|
||||
continue
|
||||
if is_excluded(m.start()):
|
||||
continue
|
||||
level = len(m.group(1))
|
||||
title = m.group(2).strip()
|
||||
results.append((m.start(), m.end(), title, level))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def estimate_token_count(content: str) -> int:
|
||||
"""Estimate token count without a tokenizer.
|
||||
|
||||
CJK characters count as ~0.7 tokens each; other non-whitespace characters
|
||||
count as ~0.3 tokens each.
|
||||
|
||||
Args:
|
||||
content: Text to estimate.
|
||||
|
||||
Returns:
|
||||
Estimated integer token count.
|
||||
"""
|
||||
cjk = re.findall(r'[\u4e00-\u9fff\u3040-\u30ff\uac00-\ud7af]', content)
|
||||
without_cjk = re.sub(r'[\u4e00-\u9fff\u3040-\u30ff\uac00-\ud7af]', '', content)
|
||||
others = re.findall(r'\S', without_cjk)
|
||||
return int(len(cjk) * 0.7 + len(others) * 0.3)
|
||||
|
||||
|
||||
def smart_split_content(
|
||||
content: str,
|
||||
max_tokens: int = 1024,
|
||||
max_chars: int = 8000,
|
||||
) -> List[str]:
|
||||
"""Split large content into parts respecting token and character limits.
|
||||
|
||||
Splits by paragraphs (double newline). If a single paragraph exceeds the
|
||||
limit it is force-cut into chunks of max_chars.
|
||||
|
||||
Args:
|
||||
content: Text to split.
|
||||
max_tokens: Maximum estimated tokens per part.
|
||||
max_chars: Maximum characters per part.
|
||||
|
||||
Returns:
|
||||
List of string parts.
|
||||
"""
|
||||
paragraphs = content.split('\n\n')
|
||||
parts: List[str] = []
|
||||
current_parts: List[str] = []
|
||||
current_tokens = 0
|
||||
current_chars = 0
|
||||
|
||||
def flush() -> None:
|
||||
if current_parts:
|
||||
parts.append('\n\n'.join(current_parts))
|
||||
current_parts.clear()
|
||||
|
||||
for para in paragraphs:
|
||||
para_tokens = estimate_token_count(para)
|
||||
para_chars = len(para)
|
||||
|
||||
# Single paragraph exceeds limits — force-cut it
|
||||
if para_tokens > max_tokens or para_chars > max_chars:
|
||||
flush()
|
||||
current_tokens = 0
|
||||
current_chars = 0
|
||||
for i in range(0, len(para), max_chars):
|
||||
parts.append(para[i:i + max_chars])
|
||||
continue
|
||||
|
||||
# Would exceed limits if added — flush first
|
||||
if (current_tokens + para_tokens > max_tokens or
|
||||
current_chars + para_chars > max_chars):
|
||||
flush()
|
||||
current_tokens = 0
|
||||
current_chars = 0
|
||||
|
||||
current_parts.append(para)
|
||||
current_tokens += para_tokens
|
||||
current_chars += para_chars
|
||||
|
||||
flush()
|
||||
return parts if parts else [content]
|
||||
|
||||
|
||||
def sanitize_for_path(text: str, max_length: int = 50) -> str:
|
||||
"""Convert text to a safe string for use in file paths.
|
||||
|
||||
Keeps word characters, CJK characters, spaces and hyphens. Replaces spaces
|
||||
with underscores. Truncates with a sha256 suffix if the result exceeds
|
||||
max_length.
|
||||
|
||||
Args:
|
||||
text: Input text to sanitize.
|
||||
max_length: Maximum length of the returned string.
|
||||
|
||||
Returns:
|
||||
Safe path-friendly string.
|
||||
"""
|
||||
cleaned = re.sub(
|
||||
r'[^\w\u4e00-\u9fff\u3040-\u30ff\uac00-\ud7af \-]',
|
||||
'',
|
||||
text,
|
||||
)
|
||||
cleaned = cleaned.replace(' ', '_').strip('_')
|
||||
|
||||
if not cleaned:
|
||||
return 'section'
|
||||
|
||||
if len(cleaned) <= max_length:
|
||||
return cleaned
|
||||
|
||||
suffix = '_' + hashlib.sha256(text.encode()).hexdigest()[:8]
|
||||
return cleaned[:max_length - len(suffix)] + suffix
|
||||
Reference in New Issue
Block a user