815 lines
26 KiB
Python
815 lines
26 KiB
Python
"""Core functional programming utilities — pure functions for list/collection operations."""
|
|
|
|
import hashlib
|
|
import re
|
|
from functools import reduce as _reduce
|
|
from typing import Any, Callable, Dict, List, Optional, Tuple
|
|
|
|
|
|
def filter_list(xs: list, pred: Callable) -> list:
    """Return a new list with the elements of ``xs`` for which ``pred`` is truthy.

    Order is preserved and ``xs`` itself is never modified.
    """
    kept = []
    for item in xs:
        if pred(item):
            kept.append(item)
    return kept
|
|
|
|
|
|
def map_list(xs: list, fn: Callable) -> list:
    """Apply ``fn`` to every element of ``xs`` and return the results as a new list.

    ``xs`` itself is never modified.
    """
    return list(map(fn, xs))
|
|
|
|
|
|
def reduce_list(xs: list, initial: Any, fn: Callable) -> Any:
    """Fold ``xs`` left-to-right into a single value.

    ``fn`` is called as ``fn(acc, x)`` and must return the next accumulator;
    ``initial`` is returned unchanged when ``xs`` is empty.
    """
    acc = initial
    for item in xs:
        acc = fn(acc, item)
    return acc
|
|
|
|
|
|
def flat_map(xs: list, fn: Callable) -> list:
    """Apply ``fn`` to each element and concatenate the iterables it returns.

    Only one level is flattened: the items yielded by each ``fn(x)`` are
    placed directly in the result.
    """
    return [produced for item in xs for produced in fn(item)]
|
|
|
|
|
|
def flatten(xss: list) -> list:
    """Concatenate the inner iterables of ``xss`` into one list (one level only)."""
    return [item for inner in xss for item in inner]
|
|
|
|
|
|
def chunk(xs: list, size: int) -> list:
    """Cut ``xs`` into consecutive slices of at most ``size`` elements.

    The final slice may be shorter. A non-positive ``size`` yields ``[]``.
    """
    if size <= 0:
        return []
    pieces = []
    start = 0
    total = len(xs)
    while start < total:
        pieces.append(xs[start:start + size])
        start += size
    return pieces
|
|
|
|
|
|
def take(xs: list, n: int) -> list:
    """Return the leading slice ``xs[:n]`` as a new list.

    Standard slice semantics apply, so a negative ``n`` drops elements from
    the end instead.
    """
    prefix = xs[:n]
    return prefix
|
|
|
|
|
|
def drop(xs: list, n: int) -> list:
    """Return the trailing slice ``xs[n:]`` as a new list.

    Standard slice semantics apply, so a negative ``n`` keeps only the last
    ``abs(n)`` elements.
    """
    remainder = xs[n:]
    return remainder
|
|
|
|
|
|
def unique(xs: list) -> list:
    """Remove duplicates while preserving first-occurrence order.

    Elements are compared by equality (``==``), not identity. Hashable
    elements are tracked in a set for O(1) lookups; unhashable elements
    (lists, dicts, ...) fall back to a linear scan instead of raising
    ``TypeError``.

    Args:
        xs: Input list; it is not modified.

    Returns:
        A new list containing the first occurrence of each distinct element.
    """
    seen_hashable = set()
    seen_unhashable = []
    result = []
    for x in xs:
        try:
            if x in seen_hashable:
                continue
            seen_hashable.add(x)
        except TypeError:
            # x is unhashable: set membership is impossible, compare by ==.
            if x in seen_unhashable:
                continue
            seen_unhashable.append(x)
        result.append(x)
    return result
|
|
|
|
|
|
def group_by(xs: list, key_fn: Callable) -> Dict:
    """Bucket the elements of ``xs`` by the value ``key_fn`` returns for each.

    Returns a dict mapping each key to the list of elements that produced it.
    Keys appear in first-seen order and elements keep their input order.
    """
    buckets: Dict = {}
    for item in xs:
        buckets.setdefault(key_fn(item), []).append(item)
    return buckets
|
|
|
|
|
|
def partition(xs: list, pred: Callable) -> Tuple[list, list]:
    """Split ``xs`` into ``(matches, non_matches)`` according to ``pred``.

    Each element lands in exactly one of the two lists and relative order is
    preserved within both.
    """
    accepted, rejected = [], []
    for item in xs:
        target = accepted if pred(item) else rejected
        target.append(item)
    return (accepted, rejected)
|
|
|
|
|
|
def find(xs: list, pred: Callable) -> Any:
    """Return the first element of ``xs`` satisfying ``pred``, or ``None`` if none does."""
    return next((item for item in xs if pred(item)), None)
|
|
|
|
|
|
def find_index(xs: list, pred: Callable) -> int:
    """Return the index of the first element satisfying ``pred``, or ``-1`` if none does."""
    return next((position for position, item in enumerate(xs) if pred(item)), -1)
|
|
|
|
|
|
def zip_with(xs: list, ys: list, fn: Callable) -> list:
    """Combine ``xs`` and ``ys`` pairwise with ``fn(x, y)``.

    Iteration stops as soon as the shorter input is exhausted.
    """
    return list(map(fn, xs, ys))
|
|
|
|
|
|
def all_of(xs: list, pred: Callable) -> bool:
    """Return True when ``pred`` holds for every element (vacuously True for an empty list)."""
    for item in xs:
        if not pred(item):
            return False
    return True
|
|
|
|
|
|
def any_of(xs: list, pred: Callable) -> bool:
    """Return True when ``pred`` holds for at least one element (False for an empty list)."""
    for item in xs:
        if pred(item):
            return True
    return False
|
|
|
|
|
|
def pipe(value: Any, *fns: Callable) -> Any:
    """Thread ``value`` through ``fns`` from left to right.

    ``pipe(x, f, g)`` is ``g(f(x))``; with no functions the value is returned
    unchanged.
    """
    return _reduce(lambda acc, step: step(acc), fns, value)
|
|
|
|
|
|
def compose(*fns: Callable) -> Callable:
    """Build a single-argument function that applies ``fns`` right-to-left.

    ``compose(f, g)(x)`` equals ``f(g(x))``; ``compose()`` is the identity.
    """
    # Reverse once up front; fns is an immutable tuple so this is safe to share.
    pipeline = fns[::-1]

    def composed(x: Any) -> Any:
        acc = x
        for step in pipeline:
            acc = step(acc)
        return acc

    return composed
|
|
|
|
|
|
# ── Tree manipulation ────────────────────────────────────────────────────────
|
|
|
|
|
|
def flatten_tree(structure: Any) -> List[Dict]:
    """Flatten a hierarchical tree into a list of child-less node copies.

    Every dict is emitted as a copy without its 'nodes' entry, followed (in
    depth-first order) by the flattened content of each value whose key
    contains the substring 'nodes'. Lists are flattened element by element;
    any other value contributes nothing.

    Only the non-'nodes' fields are deep-copied. Deep-copying the full
    subtree at every level just to throw the children away made the cost
    quadratic in tree depth.

    Args:
        structure: A node dict, a list of nodes, or anything else.

    Returns:
        List of independent node dicts; the input is not modified.
    """
    import copy

    if isinstance(structure, dict):
        node = {
            key: copy.deepcopy(value)
            for key, value in structure.items()
            if key != 'nodes'
        }
        flat = [node]
        for key, value in structure.items():
            # Non-string keys cannot name a child collection; skip them
            # instead of failing on the substring test.
            if isinstance(key, str) and 'nodes' in key:
                flat.extend(flatten_tree(value))
        return flat

    if isinstance(structure, list):
        flat = []
        for item in structure:
            flat.extend(flatten_tree(item))
        return flat

    return []
|
|
|
|
|
|
def tree_to_flat_list(structure: Any) -> List[Dict]:
    """List every node dict of a tree in depth-first (pre-order) order.

    The returned list holds the original dict objects, children still
    attached, not copies. Values that are neither dicts nor lists are ignored.
    """
    flat = []
    pending = [structure]
    while pending:
        current = pending.pop()
        if isinstance(current, dict):
            flat.append(current)
            if 'nodes' in current:
                pending.append(current['nodes'])
        elif isinstance(current, list):
            # Push in reverse so the first child is processed first.
            pending.extend(reversed(current))
    return flat
|
|
|
|
|
|
def get_leaf_nodes(structure: Any) -> List[Dict]:
    """Collect copies of the leaf nodes of a hierarchical tree.

    A dict counts as a leaf when its 'nodes' entry is missing or empty; it is
    returned as a deep copy with the 'nodes' key removed. For non-leaf dicts
    the search continues into every value whose key contains 'nodes'.
    """
    import copy

    if isinstance(structure, list):
        leaves = []
        for item in structure:
            leaves += get_leaf_nodes(item)
        return leaves

    if not isinstance(structure, dict):
        return []

    if not structure.get('nodes'):
        leaf = copy.deepcopy(structure)
        leaf.pop('nodes', None)
        return [leaf]

    leaves = []
    for key in tuple(structure):
        if 'nodes' in key:
            leaves.extend(get_leaf_nodes(structure[key]))
    return leaves
|
|
|
|
|
|
def write_node_ids(data: Any, node_id: int = 0) -> int:
    """Stamp sequential zero-padded ids onto every dict node of a tree, in place.

    Numbering starts at ``node_id`` (so the default yields '0000', '0001', ...)
    and proceeds depth-first through every value whose key contains 'nodes'.

    Returns:
        The next unused counter value.
    """
    if isinstance(data, list):
        for item in data:
            node_id = write_node_ids(item, node_id)
        return node_id

    if isinstance(data, dict):
        data['node_id'] = str(node_id).zfill(4)
        node_id += 1
        # Snapshot the keys so iteration is unaffected by changes made to
        # this dict further down the recursion.
        for key in tuple(data):
            if 'nodes' in key:
                node_id = write_node_ids(data[key], node_id)

    return node_id
|
|
|
|
|
|
def list_to_tree(data: List[Dict]) -> List[Dict]:
    """Nest a flat list of sections using their dotted structure codes.

    Each input item contributes a node with 'title', 'start_index' and
    'end_index'. An item whose code is '1.2.3' is attached under the node
    registered for '1.2' when that node has already been seen; otherwise it
    becomes a root. Nodes without children carry no 'nodes' key.
    """
    def parent_code(code):
        # '1.2.3' -> '1.2'; a code without a dot (or an empty code) has no parent.
        if not code:
            return None
        head, dot, _tail = str(code).rpartition('.')
        return head if dot else None

    by_code = {}
    roots = []

    for entry in data:
        code = entry.get('structure')
        node = {field: entry.get(field) for field in ('title', 'start_index', 'end_index')}
        node['nodes'] = []
        by_code[code] = node

        parent = parent_code(code)
        if parent and parent in by_code:
            by_code[parent]['nodes'].append(node)
        else:
            roots.append(node)

    def prune(node):
        # Drop the placeholder child list from leaves, recurse into the rest.
        children = node['nodes']
        if children:
            for child in children:
                prune(child)
        else:
            del node['nodes']
        return node

    return [prune(root) for root in roots]
|
|
|
|
|
|
def remove_tree_fields(data: Any, fields: List[str] = None) -> Any:
    """Return a copy of a dict/list tree with the given keys removed at every depth.

    ``fields`` defaults to ``['text']``. Values that are neither dicts nor
    lists are returned as-is.
    """
    excluded = ['text'] if fields is None else fields

    if isinstance(data, list):
        return [remove_tree_fields(item, excluded) for item in data]

    if isinstance(data, dict):
        cleaned = {}
        for key, value in data.items():
            if key in excluded:
                continue
            cleaned[key] = remove_tree_fields(value, excluded)
        return cleaned

    return data
|
|
|
|
|
|
def format_tree_structure(structure: Any, order: Optional[List[str]] = None) -> Any:
    """Rebuild every node of a tree with its fields in the given key order.

    Keys not listed in ``order`` are dropped from the rebuilt nodes. A
    'nodes' entry is kept only when 'nodes' is listed in ``order`` and the
    formatted child collection is non-empty.

    The input tree is left untouched. The previous implementation reassigned
    and popped 'nodes' on the caller's dicts as a side effect.

    Args:
        structure: A node dict, a list of nodes, or any other value.
        order: Desired key order. When empty or None the input is returned
            unchanged (same object).

    Returns:
        A new tree of reordered dicts, or ``structure`` itself when there is
        nothing to do.
    """
    if not order:
        return structure

    if isinstance(structure, dict):
        formatted = {}
        for key in order:
            if key not in structure:
                continue
            if key == 'nodes':
                children = format_tree_structure(structure['nodes'], order)
                if children:
                    formatted['nodes'] = children
            else:
                formatted[key] = structure[key]
        return formatted

    if isinstance(structure, list):
        return [format_tree_structure(item, order) for item in structure]

    return structure
|
|
|
|
|
|
def create_node_mapping(tree: List[Dict]) -> Dict[str, Dict]:
    """Index every node of a tree by its 'node_id' for direct lookup.

    Nodes with a missing or falsy 'node_id' are skipped (their children are
    still visited). The mapping references the original node dicts.
    """
    index = {}

    def walk(level):
        for entry in level:
            ident = entry.get('node_id')
            if ident:
                index[ident] = entry
            children = entry.get('nodes')
            if children:
                walk(children)

    walk(tree)
    return index
|
|
|
|
|
|
# ── Text / JSON extraction ───────────────────────────────────────────────────
|
|
|
|
|
|
def extract_json_from_llm(content: str) -> Dict:
    """Extract and parse JSON from an LLM response, tolerating common slip-ups.

    The payload is taken from a ```json fenced block when one is present
    (up to the last closing fence, or to the end of the text when the fence
    is never closed); otherwise the whole response is used. Parsing is then
    attempted on progressively more relaxed variants:

    1. the payload as-is, so valid JSON is never altered;
    2. with Python ``None`` rewritten to ``null`` and whitespace collapsed;
    3. additionally with trailing commas before ``]`` / ``}`` removed.

    Args:
        content: Raw model output.

    Returns:
        The parsed JSON value (normally a dict), or ``{}`` when nothing
        parses or ``content`` is not a string.
    """
    import json
    import re

    if not isinstance(content, str):
        return {}

    fence = "```json"
    start_idx = content.find(fence)
    if start_idx != -1:
        start_idx += len(fence)
        # Search for the closing fence only after the opening one; an
        # unclosed block runs to the end of the response.
        end_idx = content.rfind("```", start_idx)
        if end_idx == -1:
            end_idx = len(content)
        raw = content[start_idx:end_idx].strip()
    else:
        raw = content.strip()

    # These rewrites are lossy (they also touch text inside string values),
    # which is why the untouched payload is tried first.
    normalized = ' '.join(raw.replace('None', 'null').split())
    relaxed = re.sub(r',\s*([\]}])', r'\1', normalized)

    for candidate in (raw, normalized, relaxed):
        try:
            return json.loads(candidate)
        except (ValueError, RecursionError):
            continue
    return {}
|
|
|
|
|
|
def parse_page_range(pages: str) -> List[int]:
    """Parse a page specification into a sorted list of unique page numbers.

    The specification is a comma-separated list of single pages ('12') and
    inclusive ranges ('5-7'). Surrounding whitespace is ignored and empty
    items (empty input, doubled or trailing commas) are skipped.

    Args:
        pages: Page specification such as '5-7', '3,8' or '1-3, 9'.

    Returns:
        Sorted list of distinct page numbers; ``[]`` for an empty spec.

    Raises:
        ValueError: If an item is not an integer or a range has start > end.
    """
    collected = set()
    for part in pages.split(','):
        part = part.strip()
        if not part:
            continue
        if '-' in part:
            low_text, high_text = part.split('-', 1)
            start, end = int(low_text.strip()), int(high_text.strip())
            if start > end:
                raise ValueError(f"Invalid range '{part}': start must be <= end")
            collected.update(range(start, end + 1))
        else:
            collected.add(int(part))
    return sorted(collected)
|
|
|
|
|
|
# ── Markdown parsing ─────────────────────────────────────────────────────────
|
|
|
|
|
|
def extract_markdown_headers(markdown_content: str) -> Tuple[List[Dict], List[str]]:
    """Collect the ATX headers (h1-h6) of a markdown document.

    Lines are examined after stripping surrounding whitespace. A line that
    starts with ``` toggles code-block mode, and headers inside a code block
    are ignored.

    Returns:
        ``(headers, lines)`` where ``headers`` is a list of dicts with
        'title', 'level' and 1-based 'line_num', and ``lines`` is the
        document split on '\\n'.
    """
    import re

    heading_re = re.compile(r'^(#{1,6})\s+(.+)$')
    fence_re = re.compile(r'^```')

    lines = markdown_content.split('\n')
    headers = []
    inside_fence = False

    for number, raw_line in enumerate(lines, 1):
        text = raw_line.strip()
        if fence_re.match(text):
            inside_fence = not inside_fence
            continue
        if inside_fence or not text:
            continue
        found = heading_re.match(text)
        if found is None:
            continue
        hashes, title = found.groups()
        headers.append({'title': title.strip(), 'level': len(hashes), 'line_num': number})

    return headers, lines
|
|
|
|
|
|
def build_tree_from_headers(node_list: List[Dict]) -> List[Dict]:
    """Nest a flat, document-ordered header list by heading level.

    Each header becomes a node with 'title', a sequential zero-padded
    'node_id' (starting at '0001') and 'line_num'. A header is attached to
    the closest preceding header with a strictly smaller level; headers with
    no such ancestor are roots. Nodes without children carry no 'nodes' key.
    """
    if not node_list:
        return []

    roots = []
    ancestors = []  # stack of (node, level) for the currently open branch

    for counter, header in enumerate(node_list, start=1):
        level = header['level']
        entry = {
            'title': header['title'],
            'node_id': str(counter).zfill(4),
            'line_num': header['line_num'],
            'nodes': [],
        }

        # Close every open header that is at the same depth or deeper.
        while ancestors and ancestors[-1][1] >= level:
            ancestors.pop()

        if ancestors:
            ancestors[-1][0]['nodes'].append(entry)
        else:
            roots.append(entry)

        ancestors.append((entry, level))

    def strip_empty(entries):
        for entry in entries:
            if entry['nodes']:
                strip_empty(entry['nodes'])
            else:
                del entry['nodes']
        return entries

    return strip_empty(roots)
|
|
|
|
|
|
# ── Pagination / chunking ────────────────────────────────────────────────────
|
|
|
|
|
|
def page_list_to_groups(page_contents: List[str], token_lengths: List[int],
                        max_tokens: int = 20000, overlap_pages: int = 1) -> List[str]:
    """Group consecutive pages into text chunks bounded by a token budget.

    When the whole document fits in ``max_tokens`` a single chunk is
    returned. Otherwise pages are accumulated until adding the next one
    would exceed a target size (the mean of the even-split size and
    ``max_tokens``); the chunk is then closed and the next one starts with
    the last ``overlap_pages`` pages repeated for context.

    Args:
        page_contents: Text of each page, in order.
        token_lengths: Token count of each page, parallel to ``page_contents``.
        max_tokens: Token budget used to derive the per-chunk target.
        overlap_pages: Number of trailing pages repeated at the start of the
            next chunk.

    Returns:
        List of chunk strings. No empty leading chunk is produced when the
        very first page already exceeds the target on its own.
    """
    import math

    num_tokens = sum(token_lengths)
    if num_tokens <= max_tokens:
        return ["".join(page_contents)]

    expected_parts = math.ceil(num_tokens / max_tokens)
    avg_tokens = math.ceil(((num_tokens / expected_parts) + max_tokens) / 2)

    subsets = []
    current_subset = []
    current_token_count = 0

    for i, (page_content, page_tokens) in enumerate(zip(page_contents, token_lengths)):
        # Only close a chunk that actually holds pages; otherwise an
        # oversized first page would emit a spurious '' chunk.
        if current_subset and current_token_count + page_tokens > avg_tokens:
            subsets.append(''.join(current_subset))
            overlap_start = max(i - overlap_pages, 0)
            current_subset = list(page_contents[overlap_start:i])
            current_token_count = sum(token_lengths[overlap_start:i])

        current_subset.append(page_content)
        current_token_count += page_tokens

    if current_subset:
        subsets.append(''.join(current_subset))

    return subsets
|
|
|
|
|
|
def calculate_page_offset(pairs: List[Dict]) -> int:
    """Return the most frequent ``physical_index - page`` difference among the pairs.

    Pairs that lack either key, or whose values cannot be subtracted, are
    ignored. Returns 0 when no pair is usable; ties go to the difference
    that was encountered first.
    """
    tally: Dict[int, int] = {}
    for pair in pairs:
        try:
            delta = pair['physical_index'] - pair['page']
        except (KeyError, TypeError):
            continue
        tally[delta] = tally.get(delta, 0) + 1

    if not tally:
        return 0

    return max(tally, key=tally.get)
|
|
|
|
|
|
# ── Text preprocessing ───────────────────────────────────────────────────────
|
|
|
|
|
|
def preprocess_text(text: str) -> str:
    """Normalize whitespace and newlines in raw text.

    Args:
        text: Raw text to normalize.

    Returns:
        Text with '\\n' line endings, every line stripped of leading and
        trailing whitespace, runs of blank lines reduced to a single blank
        line, and no leading/trailing whitespace overall.
    """
    # Normalize line endings: \r\n and \r -> \n
    text = text.replace('\r\n', '\n').replace('\r', '\n')
    # Strip each line first so whitespace-only lines become truly empty;
    # collapsing before stripping would miss blank runs like "\n \n \n".
    text = '\n'.join(line.strip() for line in text.split('\n'))
    # Reduce 3+ consecutive newlines to at most 2
    text = re.sub(r'\n{3,}', '\n\n', text)
    # Strip globally
    return text.strip()
|
|
|
|
|
|
def get_text_stats(text: str) -> dict:
    """Compute simple size statistics for a text.

    Args:
        text: Input text to analyze.

    Returns:
        Dict with ``total_chars`` (length of the string), ``total_lines``
        (number of '\\n' characters plus one) and ``total_words``
        (whitespace-separated tokens).
    """
    newline_count = text.count('\n')
    words = text.split()
    return dict(
        total_chars=len(text),
        total_lines=newline_count + 1,
        total_words=len(words),
    )
|
|
|
|
|
|
# ── Git URL parsing ──────────────────────────────────────────────────────────
|
|
|
|
# Hostnames accepted by the git URL helpers below when the caller does not
# pass an explicit ``known_hosts`` list.
_DEFAULT_GIT_HOSTS = ["github.com", "gitlab.com"]
|
|
|
|
|
|
def _sanitize_git_segment(segment: str) -> str:
|
|
"""Strip .git suffix then keep only [a-zA-Z0-9_-] chars."""
|
|
if segment.endswith(".git"):
|
|
segment = segment[:-4]
|
|
return re.sub(r"[^a-zA-Z0-9_\-]", "", segment)
|
|
|
|
|
|
def parse_git_url(url: str, known_hosts: Optional[List[str]] = None) -> Optional[str]:
    """Parse a code-hosting URL and return the 'org/repo' path component.

    Supports HTTPS, HTTP, git://, ssh:// and SSH shorthand (git@host:path).
    Returns None if the URL does not match any known host or is malformed.

    Host names are compared case-insensitively for every URL form. Before,
    only the scheme-based forms were, because ``urlparse`` lowercases the
    host while the ``git@`` branch compared the raw text.

    Args:
        url: Repository URL in any supported format.
        known_hosts: List of accepted hostnames. Defaults to github.com and gitlab.com.

    Returns:
        'org/repo' string or None.
    """
    from urllib.parse import urlparse

    hosts = known_hosts if known_hosts is not None else _DEFAULT_GIT_HOSTS
    accepted = {host.lower() for host in hosts}
    url = url.strip()

    if url.startswith("git@"):
        # git@github.com:org/repo.git
        host, colon, path = url[len("git@"):].partition(":")
        if not colon:
            return None
    elif url.startswith(("http://", "https://", "git://", "ssh://")):
        try:
            parsed = urlparse(url)
            host = parsed.hostname or ""
        except ValueError:
            # urlparse rejects e.g. unbalanced IPv6 brackets; treat as malformed.
            return None
        path = parsed.path
    else:
        return None

    if host.lower() not in accepted:
        return None

    segments = [s for s in path.split("/") if s]
    if len(segments) < 2:
        return None

    org = _sanitize_git_segment(segments[0])
    repo = _sanitize_git_segment(segments[1])
    if not org or not repo:
        return None
    return f"{org}/{repo}"
|
|
|
|
|
|
def is_git_repo_url(url: str, known_hosts: Optional[List[str]] = None) -> bool:
    """Return True only if url points to a clonable git repository.

    Accepts org/repo and org/repo/tree/<ref> paths.
    Rejects paths that navigate to sub-resources (issues, blobs, PRs, etc.).

    Host names are compared case-insensitively for every URL form, and a URL
    that ``urlparse`` rejects is reported as False instead of raising.

    Args:
        url: URL to verify.
        known_hosts: Accepted hostnames. Defaults to github.com and gitlab.com.

    Returns:
        True if url is a clonable repository URL.
    """
    from urllib.parse import urlparse

    hosts = known_hosts if known_hosts is not None else _DEFAULT_GIT_HOSTS
    accepted = {host.lower() for host in hosts}
    url = url.strip()

    # SSH shorthand — always repo-level if host matches
    if url.startswith("git@"):
        host, colon, _ = url[len("git@"):].partition(":")
        return bool(colon) and host.lower() in accepted

    if not url.startswith(("ssh://", "git://", "http://", "https://")):
        return False

    try:
        parsed = urlparse(url)
        host = (parsed.hostname or "").lower()
    except ValueError:
        return False
    if host not in accepted:
        return False

    # git:// and ssh:// — always repo-level if host matches
    if url.startswith(("ssh://", "git://")):
        return True

    # http:// and https:// — must have exactly org/repo or org/repo/tree/<ref>
    segments = [s for s in parsed.path.split("/") if s]
    if len(segments) == 2:
        return True
    return len(segments) == 4 and segments[2] == "tree"
|
|
|
|
|
|
def validate_git_ssh_uri(url: str) -> None:
    """Validate a git SSH URI of the form git@host:path.

    Raises ValueError with a descriptive message if the URI is malformed.

    Args:
        url: URI string to validate.

    Raises:
        ValueError: If the URI does not start with 'git@', has no ':'
            separator, or has an empty host or an empty path.
    """
    if not url.startswith("git@"):
        raise ValueError(f"git SSH URI must start with 'git@', got: {url!r}")

    host, colon, path = url[len("git@"):].partition(":")
    if not colon:
        raise ValueError(f"git SSH URI must contain ':', got: {url!r}")
    # 'git@:org/repo' has the right shape but names no host.
    if not host:
        raise ValueError(f"git SSH URI must have a non-empty host before ':', got: {url!r}")
    if not path:
        raise ValueError(f"git SSH URI must have a non-empty path after ':', got: {url!r}")
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Markdown parsing utilities
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def extract_frontmatter(content: str) -> Tuple[str, Optional[Dict]]:
    """Extract YAML frontmatter delimited by '---' from the start of a markdown string.

    The opening '---' must be the very first line. Both LF and CRLF line
    endings are accepted, the delimiters may carry trailing spaces/tabs, the
    closing delimiter may be the last line of the document, and the
    frontmatter body may be empty.

    Args:
        content: Raw markdown string, optionally starting with YAML frontmatter.

    Returns:
        Tuple of (content_without_frontmatter, frontmatter_dict).
        frontmatter_dict is None when no frontmatter is found, or when PyYAML
        parses the block to something other than a mapping.
    """
    pattern = re.compile(
        r'\A---[ \t]*\r?\n(?:(.*?)\r?\n)?---[ \t]*(?:\r?\n|\Z)',
        re.DOTALL,
    )
    match = pattern.match(content)
    if not match:
        return content, None

    raw = match.group(1) or ''
    remaining = content[match.end():]

    try:
        import yaml  # type: ignore

        data = yaml.safe_load(raw)
        if not isinstance(data, dict):
            data = None
    except Exception:
        # Deliberate best-effort: PyYAML missing or the block is not valid
        # YAML. Fall back to a flat "key: value" line parser.
        data = {}
        for line in raw.splitlines():
            if ':' in line:
                key, _, value = line.partition(':')
                data[key.strip()] = value.strip()

    return remaining, data
|
|
|
|
|
|
def find_headings(content: str) -> List[Tuple[int, int, str, int]]:
    """Find all markdown headings (# to ######), excluding those inside code blocks,
    HTML comments, and indented blocks.

    A heading must start at the first column of a line with 1-6 '#'
    characters, followed by spaces/tabs and a non-empty title on the same
    line. Indented lines and escaped hashes ('\\#') therefore can never
    match, so no separate exclusion pass is needed for them.

    Args:
        content: Markdown text to search.

    Returns:
        List of (start_pos, end_pos, title, level) for each heading found.
    """
    excluded: List[Tuple[int, int]] = []

    # Code blocks (triple backtick)
    for m in re.finditer(r'```.*?```', content, re.DOTALL):
        excluded.append((m.start(), m.end()))

    # HTML comments
    for m in re.finditer(r'<!--.*?-->', content, re.DOTALL):
        excluded.append((m.start(), m.end()))

    def is_excluded(pos: int) -> bool:
        return any(start <= pos < end for start, end in excluded)

    results: List[Tuple[int, int, str, int]] = []
    # [ \t]+ rather than \s+: \s also matches newlines, which let a bare '#'
    # line swallow the following line as its title.
    for m in re.finditer(r'^(#{1,6})[ \t]+(\S.*)$', content, re.MULTILINE):
        if is_excluded(m.start()):
            continue
        level = len(m.group(1))
        title = m.group(2).strip()
        results.append((m.start(), m.end(), title, level))

    return results
|
|
|
|
|
|
def estimate_token_count(content: str) -> int:
    """Estimate token count without a tokenizer.

    Characters in the CJK ranges matched below are weighted at 0.7 tokens
    each; every other non-whitespace character is weighted at 0.3. The sum
    is truncated to an int.

    Args:
        content: Text to estimate.

    Returns:
        Estimated integer token count.
    """
    cjk_class = r'[\u4e00-\u9fff\u3040-\u30ff\uac00-\ud7af]'
    # subn removes the CJK characters and reports how many there were.
    remainder, cjk_count = re.subn(cjk_class, '', content)
    other_count = sum(1 for _ in re.finditer(r'\S', remainder))
    return int(cjk_count * 0.7 + other_count * 0.3)
|
|
|
|
|
|
def smart_split_content(
    content: str,
    max_tokens: int = 1024,
    max_chars: int = 8000,
) -> List[str]:
    """Split large content into parts respecting token and character limits.

    Paragraphs (separated by a blank line) are packed greedily into parts.
    A paragraph that exceeds either limit on its own is emitted separately,
    cut into consecutive slices of ``max_chars`` characters.

    Args:
        content: Text to split.
        max_tokens: Maximum estimated tokens per part.
        max_chars: Maximum characters per part.

    Returns:
        List of string parts; ``[content]`` when nothing was produced.
    """
    parts: List[str] = []
    pending: List[str] = []
    pending_tokens = 0
    pending_chars = 0

    for para in content.split('\n\n'):
        tokens = estimate_token_count(para)
        size = len(para)

        oversized = tokens > max_tokens or size > max_chars
        overflow = (pending_tokens + tokens > max_tokens
                    or pending_chars + size > max_chars)

        # In both cases the paragraphs gathered so far form a finished part.
        if oversized or overflow:
            if pending:
                parts.append('\n\n'.join(pending))
                pending = []
            pending_tokens = 0
            pending_chars = 0

        if oversized:
            # Single paragraph exceeds limits — force-cut it
            parts.extend(para[i:i + max_chars] for i in range(0, size, max_chars))
            continue

        pending.append(para)
        pending_tokens += tokens
        pending_chars += size

    if pending:
        parts.append('\n\n'.join(pending))

    return parts if parts else [content]
|
|
|
|
|
|
def sanitize_for_path(text: str, max_length: int = 50) -> str:
    """Convert text to a safe string for use in file paths.

    Keeps word characters, CJK characters, spaces and hyphens. Replaces spaces
    with underscores. When the cleaned text is longer than ``max_length`` it
    is truncated and suffixed with '_' plus the first 8 hex digits of the
    sha256 of the original text, so distinct long inputs stay distinct.

    Args:
        text: Input text to sanitize.
        max_length: Maximum length of the returned string.

    Returns:
        Safe path-friendly string. 'section' is returned when nothing
        survives cleaning. When ``max_length`` leaves no room for any text
        before the hash suffix, a prefix of the hash alone (at least one
        character) is returned.
    """
    cleaned = re.sub(
        r'[^\w\u4e00-\u9fff\u3040-\u30ff\uac00-\ud7af \-]',
        '',
        text,
    )
    cleaned = cleaned.replace(' ', '_').strip('_')

    if not cleaned:
        return 'section'

    if len(cleaned) <= max_length:
        return cleaned

    digest = hashlib.sha256(text.encode()).hexdigest()[:8]
    keep = max_length - len(digest) - 1  # room left for text before '_<digest>'
    if keep <= 0:
        # A negative slice bound here would keep almost the whole string and
        # overshoot max_length; fall back to the hash alone.
        return digest[:max(max_length, 1)]
    return f"{cleaned[:keep]}_{digest}"
|