#!/usr/bin/env python3
"""Kompanion ingestion runner.

Reads pipeline configuration (YAML), walks source trees, chunks content,
fetches embeddings, and upserts into the retrieval schema described in
docs/db-ingest.md.
"""

from __future__ import annotations

import argparse
import fnmatch
import hashlib
import json
import logging
import os
import time
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterable, Iterator, List, Optional, Sequence, Set, Tuple

import psycopg
import requests
import yaml
from psycopg import sql

# -------------------------
# Helper data structures
# -------------------------


@dataclass
class EmbedConfig:
    endpoint: str
    dim: int
    normalize: bool
    batch_size: int
    rate_limit_per_sec: Optional[float]


@dataclass
class ChunkingDocConfig:
    max_tokens: int = 700
    overlap_tokens: int = 120


@dataclass
class ChunkingCodeConfig:
    body_head_lines: int = 60
    include_doc_comment: bool = True
    signature_first: bool = True
    attach_file_context: bool = True


@dataclass
class ChunkingConfig:
    docs: ChunkingDocConfig
    code: ChunkingCodeConfig


@dataclass
class DbConfig:
    dsn: str
    schema: Optional[str]
    items_table: str
    chunks_table: str
    embeddings_table: str


@dataclass
class SourceConfig:
    name: str
    root: Path
    include: Sequence[str]
    exclude: Sequence[str]
    framework: str
    version: str
    kind_overrides: Dict[str, str]


@dataclass
class PipelineConfig:
    embed: EmbedConfig
    chunking: ChunkingConfig
    db: DbConfig
    sources: List[SourceConfig]
    default_lang: Optional[str]


def load_pipeline_config(path: Path) -> PipelineConfig:
    raw = yaml.safe_load(path.read_text())
    embed_raw = raw["pipeline"]["embed"]
    embed = EmbedConfig(
        endpoint=embed_raw["endpoint"],
        dim=int(embed_raw.get("dim", 1024)),
        normalize=bool(embed_raw.get("normalize", True)),
        batch_size=int(embed_raw.get("batch_size", 64)),
        rate_limit_per_sec=float(embed_raw.get("rate_limit_per_sec", 0)) or None,
    )
    docs_raw = raw["pipeline"]["chunking"].get("docs", {})
    docs_cfg = ChunkingDocConfig(
        max_tokens=int(docs_raw.get("max_tokens", 700)),
        overlap_tokens=int(docs_raw.get("overlap_tokens", 120)),
    )
    code_raw = raw["pipeline"]["chunking"].get("code", {})
    code_cfg = ChunkingCodeConfig(
        body_head_lines=int(code_raw.get("body_head_lines", 60)),
        include_doc_comment=bool(code_raw.get("include_doc_comment", True)),
        signature_first=bool(code_raw.get("signature_first", True)),
        attach_file_context=bool(code_raw.get("attach_file_context", True)),
    )
    chunking = ChunkingConfig(docs=docs_cfg, code=code_cfg)
    db_raw = raw["pipeline"]["db"]
    schema = db_raw.get("schema")
    db = DbConfig(
        dsn=db_raw["dsn"],
        schema=schema,
        items_table=db_raw["tables"]["items"],
        chunks_table=db_raw["tables"]["chunks"],
        embeddings_table=db_raw["tables"]["embeddings"],
    )
    metadata_raw = raw["pipeline"].get("metadata", {}).get("compute", [])
    default_lang = None
    for entry in metadata_raw:
        if entry.get("name") == "lang" and "value" in entry:
            default_lang = entry["value"]
    sources = []
    for src_raw in raw["pipeline"]["sources"]:
        include = src_raw.get("include", ["**"])
        exclude = src_raw.get("exclude", [])
        overrides = {}
        for entry in src_raw.get("kind_overrides", []):
            overrides[entry["pattern"]] = entry["kind"]
        sources.append(
            SourceConfig(
                name=src_raw["name"],
                root=Path(src_raw["root"]),
                include=include,
                exclude=exclude,
                framework=src_raw.get("framework", ""),
                version=src_raw.get("version", ""),
                kind_overrides=overrides,
            )
        )
    return PipelineConfig(
        embed=embed,
        chunking=chunking,
        db=db,
        sources=sources,
        default_lang=default_lang,
    )

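# Illustrative shape of the pipeline YAML consumed by load_pipeline_config() above.
# Key names mirror what the loader reads; the concrete values (endpoint, DSN, paths,
# framework names, versions) are placeholders, not a reference configuration.
#
#   pipeline:
#     embed:
#       endpoint: http://localhost:8080/embed        # placeholder URL
#       dim: 1024
#       normalize: true
#       batch_size: 64
#       rate_limit_per_sec: 4
#     chunking:
#       docs: {max_tokens: 700, overlap_tokens: 120}
#       code: {body_head_lines: 60, include_doc_comment: true}
#     db:
#       dsn: postgresql://kompanion@localhost/kompanion   # placeholder DSN
#       schema: retrieval                                 # optional
#       tables: {items: items, chunks: chunks, embeddings: embeddings}
#     metadata:
#       compute:
#         - {name: lang, value: en}
#     sources:
#       - name: example-source                            # placeholder source
#         root: /path/to/checkout
#         include: ["docs/**", "src/**"]
#         exclude: ["build/**"]
#         framework: ExampleFramework
#         version: "1.0"
#         kind_overrides:
#           - {pattern: "src/**", kind: code_symbol}
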
# -------------------------
# Utility functions
# -------------------------

DOC_EXTENSIONS = {".md", ".rst", ".qdoc", ".qml", ".txt"}
CODE_EXTENSIONS = {
    ".c", ".cc", ".cxx", ".cpp", ".h", ".hpp", ".hh", ".hxx", ".qml", ".mm",
}


def hash_text(text: str) -> str:
    return hashlib.sha1(text.encode("utf-8")).hexdigest()


def estimate_tokens(text: str) -> int:
    return max(1, len(text.strip().split()))


def path_matches(patterns: Sequence[str], rel_path: str) -> bool:
    return any(fnmatch.fnmatch(rel_path, pattern) for pattern in patterns)


def detect_kind(rel_path: str, overrides: Dict[str, str]) -> str:
    for pattern, kind in overrides.items():
        if fnmatch.fnmatch(rel_path, pattern):
            return kind
    suffix = Path(rel_path).suffix.lower()
    if suffix in DOC_EXTENSIONS:
        return "api_doc"
    return "code_symbol"


# -------------------------
# CTags handling
# -------------------------


class CtagsIndex:
    """Stores ctags JSON entries indexed by path."""

    def __init__(self) -> None:
        self._by_path: Dict[str, List[dict]] = defaultdict(list)

    @staticmethod
    def _normalize(path: str) -> str:
        return Path(path).as_posix()

    def add(self, entry: dict) -> None:
        path = entry.get("path")
        if not path:
            return
        self._by_path[self._normalize(path)].append(entry)

    def extend_from_file(self, path: Path) -> None:
        with path.open("r", encoding="utf-8", errors="ignore") as handle:
            for line in handle:
                line = line.strip()
                if not line:
                    continue
                try:
                    entry = json.loads(line)
                except json.JSONDecodeError:
                    continue
                self.add(entry)

    def for_file(self, file_path: Path, source_root: Path) -> List[dict]:
        rel = file_path.relative_to(source_root).as_posix()
        candidates = self._by_path.get(rel)
        if candidates:
            return sorted(candidates, key=lambda e: e.get("line", e.get("lineNumber", 0)))
        return sorted(
            self._by_path.get(file_path.as_posix(), []),
            key=lambda e: e.get("line", e.get("lineNumber", 0)),
        )


# -------------------------
# Chunk generators
# -------------------------


def iter_doc_sections(text: str) -> Iterator[Tuple[str, str]]:
    """Yield (section_path, section_text) pairs based on markdown headings/code fences."""
    lines = text.splitlines()
    heading_stack: List[Tuple[int, str]] = []
    buffer: List[str] = []
    section_path = ""
    in_code = False
    code_delim = ""

    def flush():
        nonlocal buffer
        if buffer:
            section_text = "\n".join(buffer).strip()
            if section_text:
                yield_path = section_path or "/".join(h[1] for h in heading_stack)
                yield (yield_path, section_text)
            buffer = []

    for line in lines:
        stripped = line.strip()
        if in_code:
            buffer.append(line)
            if stripped.startswith(code_delim):
                yield from flush()
                in_code = False
                code_delim = ""
            continue
        if stripped.startswith("```") or stripped.startswith("~~~"):
            yield from flush()
            in_code = True
            code_delim = stripped[:3]
            buffer = [line]
            continue
        if stripped.startswith("#"):
            yield from flush()
            level = len(stripped) - len(stripped.lstrip("#"))
            title = stripped[level:].strip()
            while heading_stack and heading_stack[-1][0] >= level:
                heading_stack.pop()
            heading_stack.append((level, title))
            section_path = "/".join(h[1] for h in heading_stack)
            continue
        buffer.append(line)
    yield from flush()

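# Worked example for the section splitter above (illustrative input; the output is
# derived by tracing iter_doc_sections, not taken from a stored fixture):
#
#   list(iter_doc_sections("# A\ntext\n## B\nmore"))
#   -> [("A", "text"), ("A/B", "more")]
#
# Heading titles accumulate into a "/"-joined path, and fenced code blocks are kept
# whole as their own sections under the current heading.
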
def chunk_doc_text(text: str, chunk_cfg: ChunkingDocConfig) -> Iterator[Tuple[str, str]]:
    if not text.strip():
        return
    for section_path, section_text in iter_doc_sections(text):
        tokens = section_text.split()
        if not tokens:
            continue
        max_tokens = max(1, chunk_cfg.max_tokens)
        overlap = min(chunk_cfg.overlap_tokens, max_tokens - 1) if max_tokens > 1 else 0
        step = max(1, max_tokens - overlap)
        for start in range(0, len(tokens), step):
            window = tokens[start : start + max_tokens]
            chunk = " ".join(window)
            yield section_path, chunk


def extract_doc_comment(lines: List[str], start_index: int) -> List[str]:
    doc_lines: List[str] = []
    i = start_index - 1
    saw_content = False
    while i >= 0:
        raw = lines[i]
        stripped = raw.strip()
        if not stripped:
            if saw_content:
                break
            i -= 1
            continue
        if (
            stripped.startswith("//")
            or stripped.startswith("///")
            or stripped.startswith("/*")
            or stripped.startswith("*")
        ):
            doc_lines.append(raw)
            saw_content = True
            i -= 1
            continue
        break
    doc_lines.reverse()
    return doc_lines


def chunk_code_text(
    path: Path,
    text: str,
    chunk_cfg: ChunkingCodeConfig,
    tags: Sequence[dict],
    source_root: Path,
) -> Iterator[Tuple[str, str]]:
    lines = text.splitlines()
    if not lines:
        return
    used_symbols: Set[str] = set()
    if tags:
        for tag in tags:
            line_no = tag.get("line") or tag.get("lineNumber")
            if not isinstance(line_no, int) or line_no <= 0 or line_no > len(lines):
                continue
            index = line_no - 1
            snippet_lines: List[str] = []
            if chunk_cfg.include_doc_comment:
                snippet_lines.extend(extract_doc_comment(lines, index))
            if chunk_cfg.signature_first:
                snippet_lines.append(lines[index])
            body_tail = lines[index + 1 : index + 1 + chunk_cfg.body_head_lines]
            snippet_lines.extend(body_tail)
            snippet = "\n".join(snippet_lines).strip()
            if not snippet:
                continue
            symbol_name = tag.get("name") or ""
            used_symbols.add(symbol_name)
            yield symbol_name, snippet
    if not tags or chunk_cfg.attach_file_context:
        head = "\n".join(lines[: chunk_cfg.body_head_lines]).strip()
        if head:
            symbol = "::file_head"
            if symbol not in used_symbols:
                yield symbol, head


# -------------------------
# Embedding + database IO
# -------------------------


class EmbedClient:
    def __init__(self, config: EmbedConfig):
        self.endpoint = config.endpoint
        self.batch_size = config.batch_size
        self.normalize = config.normalize
        self.dim = config.dim
        self.rate_limit = config.rate_limit_per_sec
        self._last_request_ts: float = 0.0
        self._session = requests.Session()

    def _respect_rate_limit(self) -> None:
        if not self.rate_limit:
            return
        min_interval = 1.0 / self.rate_limit
        now = time.time()
        delta = now - self._last_request_ts
        if delta < min_interval:
            time.sleep(min_interval - delta)

    def embed(self, texts: Sequence[str]) -> List[List[float]]:
        if not texts:
            return []
        self._respect_rate_limit()
        response = self._session.post(
            self.endpoint,
            json={"inputs": list(texts)},
            timeout=120,
        )
        response.raise_for_status()
        payload = response.json()
        if isinstance(payload, dict) and "embeddings" in payload:
            vectors = payload["embeddings"]
        else:
            vectors = payload
        normalized_vectors: List[List[float]] = []
        for vec in vectors:
            if not isinstance(vec, (list, tuple)):
                raise ValueError("Embedding response contained non-list entry")
            normalized_vectors.append([float(x) for x in vec])
        self._last_request_ts = time.time()
        return normalized_vectors

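# Assumptions behind the IO classes above and below (sketch-level, not verified
# against a particular deployment):
#   * EmbedClient posts {"inputs": [...]} and accepts either a bare list of vectors
#     or an {"embeddings": [...]} envelope, which matches TEI-style embedding
#     servers; other servers may need a different payload shape.
#   * The normalize/dim settings are carried on the client but not applied here,
#     so any normalization is assumed to happen server-side.
#   * DatabaseWriter hands psycopg a plain Python list for the embedding column and
#     relies on the connection adapting it to the column type; for a pgvector
#     column this typically means registering the adapter first, e.g.
#     pgvector.psycopg.register_vector(conn), or casting explicitly in SQL.
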
class DatabaseWriter:
    def __init__(self, cfg: DbConfig):
        self.cfg = cfg
        self.conn = psycopg.connect(cfg.dsn)
        self.conn.autocommit = False
        schema = cfg.schema
        if schema:
            self.items_table = sql.Identifier(schema, cfg.items_table)
            self.chunks_table = sql.Identifier(schema, cfg.chunks_table)
            self.embeddings_table = sql.Identifier(schema, cfg.embeddings_table)
        else:
            self.items_table = sql.Identifier(cfg.items_table)
            self.chunks_table = sql.Identifier(cfg.chunks_table)
            self.embeddings_table = sql.Identifier(cfg.embeddings_table)

    def close(self) -> None:
        self.conn.close()

    def upsert_item(
        self,
        external_id: str,
        kind: str,
        framework: str,
        version: str,
        meta: dict,
        lang: Optional[str],
    ) -> int:
        with self.conn.cursor() as cur:
            cur.execute(
                sql.SQL(
                    """
                    INSERT INTO {} (external_id, kind, framework, version, meta, lang)
                    VALUES (%s,%s,%s,%s,%s,%s)
                    ON CONFLICT (external_id) DO UPDATE SET
                        framework = EXCLUDED.framework,
                        version = EXCLUDED.version,
                        meta = EXCLUDED.meta,
                        lang = EXCLUDED.lang,
                        updated_at = now()
                    RETURNING id
                    """
                ).format(self.items_table),
                (external_id, kind, framework, version, json.dumps(meta), lang),
            )
            row = cur.fetchone()
            assert row is not None
            return int(row[0])

    def upsert_chunk(
        self,
        item_id: int,
        content: str,
        symbol: Optional[str],
        section_path: Optional[str],
        modality: str,
    ) -> Tuple[int, str]:
        digest = hash_text(content)
        with self.conn.cursor() as cur:
            cur.execute(
                sql.SQL(
                    """
                    INSERT INTO {} (item_id, content, token_count, symbol, section_path, modality, hash)
                    VALUES (%s,%s,%s,%s,%s,%s,%s)
                    ON CONFLICT (hash) DO UPDATE SET
                        item_id = EXCLUDED.item_id,
                        content = EXCLUDED.content,
                        token_count = EXCLUDED.token_count,
                        symbol = EXCLUDED.symbol,
                        section_path = EXCLUDED.section_path,
                        modality = EXCLUDED.modality,
                        created_at = now()
                    RETURNING id, hash
                    """
                ).format(self.chunks_table),
                (
                    item_id,
                    content,
                    estimate_tokens(content),
                    symbol,
                    section_path,
                    modality,
                    digest,
                ),
            )
            row = cur.fetchone()
            assert row is not None
            return int(row[0]), str(row[1])

    def upsert_embedding(self, chunk_id: int, vector: Sequence[float]) -> None:
        with self.conn.cursor() as cur:
            cur.execute(
                sql.SQL(
                    """
                    INSERT INTO {} (chunk_id, embedding)
                    VALUES (%s,%s)
                    ON CONFLICT (chunk_id) DO UPDATE SET
                        embedding = EXCLUDED.embedding,
                        created_at = now()
                    """
                ).format(self.embeddings_table),
                (chunk_id, vector),
            )

    def commit(self) -> None:
        self.conn.commit()


# -------------------------
# Ingestion runner
# -------------------------


def gather_files(source: SourceConfig) -> Iterator[Tuple[Path, str, str, str]]:
    root = source.root
    if not root.exists():
        logging.warning("Source root %s does not exist, skipping", root)
        return
    include_patterns = source.include or ["**"]
    exclude_patterns = source.exclude or []
    for path in root.rglob("*"):
        if path.is_dir():
            continue
        rel = path.relative_to(root).as_posix()
        if include_patterns and not path_matches(include_patterns, rel):
            continue
        if exclude_patterns and path_matches(exclude_patterns, rel):
            continue
        try:
            text = path.read_text(encoding="utf-8", errors="ignore")
        except Exception as exc:  # noqa: BLE001
            logging.debug("Failed reading %s: %s", path, exc)
            continue
        kind = detect_kind(rel, source.kind_overrides)
        yield path, rel, kind, text


def enrich_meta(source: SourceConfig, rel: str, extra: Optional[dict] = None) -> dict:
    meta = {
        "source": source.name,
        "path": rel,
    }
    if extra:
        meta.update(extra)
    return meta

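# Identifier conventions used by ingest_source() below: each file becomes one item
# keyed by the external id "repo:<source-name>:<relative-path>", and enrich_meta()
# always records the source name and relative path in the item's meta, e.g.
#
#   enrich_meta(src, "docs/intro.md", {"symbols": ["Foo"]})
#   -> {"source": src.name, "path": "docs/intro.md", "symbols": ["Foo"]}
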
def ingest_source(
    source: SourceConfig,
    cfg: PipelineConfig,
    ctags_index: CtagsIndex,
    embed_client: EmbedClient,
    db: DatabaseWriter,
) -> None:
    doc_cfg = cfg.chunking.docs
    code_cfg = cfg.chunking.code
    lang = cfg.default_lang
    batch_texts: List[str] = []
    batch_chunk_ids: List[int] = []

    def flush_batch() -> None:
        nonlocal batch_texts, batch_chunk_ids
        if not batch_texts:
            return
        vectors = embed_client.embed(batch_texts)
        if len(vectors) != len(batch_chunk_ids):
            raise RuntimeError("Embedding count mismatch.")
        for chunk_id, vector in zip(batch_chunk_ids, vectors):
            db.upsert_embedding(chunk_id, vector)
        db.commit()
        batch_texts = []
        batch_chunk_ids = []

    processed = 0
    for path, rel, kind, text in gather_files(source):
        processed += 1
        meta = enrich_meta(source, rel)
        item_external_id = f"repo:{source.name}:{rel}"
        item_id = db.upsert_item(
            external_id=item_external_id,
            kind=kind,
            framework=source.framework,
            version=source.version,
            meta=meta,
            lang=lang,
        )
        if kind == "api_doc":
            for section_path, chunk_text in chunk_doc_text(text, doc_cfg):
                chunk_id, _ = db.upsert_chunk(
                    item_id=item_id,
                    content=chunk_text,
                    symbol=None,
                    section_path=section_path or None,
                    modality="text",
                )
                batch_texts.append(chunk_text)
                batch_chunk_ids.append(chunk_id)
                if len(batch_texts) >= embed_client.batch_size:
                    flush_batch()
        else:
            tags = ctags_index.for_file(path, source.root)
            symbols: List[str] = []
            for symbol_name, chunk_text in chunk_code_text(path, text, code_cfg, tags, source.root):
                symbols.append(symbol_name)
                chunk_id, _ = db.upsert_chunk(
                    item_id=item_id,
                    content=chunk_text,
                    symbol=symbol_name or None,
                    section_path=None,
                    modality="text",
                )
                batch_texts.append(chunk_text)
                batch_chunk_ids.append(chunk_id)
                if len(batch_texts) >= embed_client.batch_size:
                    flush_batch()
            if symbols:
                db.upsert_item(
                    external_id=item_external_id,
                    kind=kind,
                    framework=source.framework,
                    version=source.version,
                    meta=enrich_meta(source, rel, {"symbols": symbols}),
                    lang=lang,
                )
    flush_batch()
    if processed:
        logging.info("Processed %d files from %s", processed, source.name)


def run_ingest(config_path: Path, ctags_paths: Sequence[Path]) -> None:
    pipeline_cfg = load_pipeline_config(config_path)
    embed_client = EmbedClient(pipeline_cfg.embed)
    db_writer = DatabaseWriter(pipeline_cfg.db)
    ctags_index = CtagsIndex()
    for ctags_path in ctags_paths:
        if ctags_path.exists():
            ctags_index.extend_from_file(ctags_path)
        else:
            logging.warning("ctags file %s missing; skipping", ctags_path)
    try:
        for source in pipeline_cfg.sources:
            ingest_source(
                source=source,
                cfg=pipeline_cfg,
                ctags_index=ctags_index,
                embed_client=embed_client,
                db=db_writer,
            )
    finally:
        db_writer.commit()
        db_writer.close()


def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Kompanion ingestion runner")
    parser.add_argument("--config", required=True, type=Path, help="Pipeline YAML path")
    parser.add_argument(
        "--ctags",
        nargs="*",
        type=Path,
        default=[],
        help="Zero or more ctags JSON files (one JSON object per line)",
    )
    parser.add_argument(
        "--log-level",
        default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
    )
    return parser.parse_args(argv)


def main(argv: Optional[Sequence[str]] = None) -> None:
    args = parse_args(argv)
    logging.basicConfig(level=getattr(logging, args.log_level), format="%(levelname)s %(message)s")
    run_ingest(args.config, args.ctags)


if __name__ == "__main__":
    main()
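# Example invocation (file names are placeholders; the ctags input is expected to
# hold one JSON object per line, e.g. Universal Ctags' JSON output):
#
#   python3 ingest.py --config pipeline.yml --ctags tags.json --log-level DEBUG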