49 lines
1.8 KiB
Python
49 lines
1.8 KiB
Python
#!/usr/bin/env python3
|
|
import os, sys, json, requests, psycopg
|
|
|
|
# Runtime configuration; each value can be overridden through the environment.
# libpq connection string for the PostgreSQL database holding the komp schema.
DB = os.getenv("DB_URL", "dbname=kompanion user=kompanion host=/var/run/postgresql")
# Base URL of the Ollama HTTP API used to compute embeddings.
OLLAMA = os.getenv("OLLAMA_BASE", "http://127.0.0.1:11434")
# Name of the embedding model requested from Ollama.
MODEL = os.getenv("EMBED_MODEL", "mxbai-embed-large")
# Name of the row in komp.space that selects which embeddings are searched.
SPACE = os.getenv("EMBED_SPACE", "dev_knowledge")

# Usage text printed to stderr when the script is called without a query.
HELP = """\
Usage: pg_search.py "query text" [k]
Env: DB_URL, OLLAMA_BASE, EMBED_MODEL, EMBED_SPACE (default dev_knowledge)
Prints JSON results: [{score, uri, lineno, text}].
"""
|
|
|
|
def embed(q: str):
    """Return the embedding of *q* as reported by the Ollama embeddings API.

    Sends the configured ``MODEL`` and the prompt text as a JSON POST to
    ``{OLLAMA}/api/embeddings`` (120 second timeout) and returns the value
    stored under the ``"embedding"`` key of the JSON reply.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        KeyError: if the reply carries no ``"embedding"`` field.
    """
    endpoint = f"{OLLAMA}/api/embeddings"
    payload = {"model": MODEL, "prompt": q}
    response = requests.post(endpoint, json=payload, timeout=120)
    response.raise_for_status()
    reply = response.json()
    return reply["embedding"]
|
|
|
|
def main() -> None:
    """Embed the query given on the command line and print the nearest chunks.

    Command line: ``pg_search.py "query text" [k]`` where *k* (default 8) is
    the maximum number of results.

    Steps: embed the query via ``embed()``, look up the configured ``SPACE``
    in ``komp.space`` to get its id and dimension, run a nearest-neighbour
    query against ``komp.embedding_<dim>``, and print a JSON list of
    ``{score, uri, lineno, text}`` objects to stdout.

    Exits with status 1 and a message on stderr when the query is missing,
    *k* is not a non-negative integer, the space does not exist, the space
    dimension is unsupported, or the model returns a vector whose length
    differs from the space dimension.
    """
    if len(sys.argv) < 2:
        print(HELP, file=sys.stderr)
        sys.exit(1)
    query = sys.argv[1]

    # Validate k before any network or database work so that bad input gives
    # a short message instead of a traceback or a database LIMIT error.
    k = 8
    if len(sys.argv) > 2:
        try:
            k = int(sys.argv[2])
        except ValueError:
            sys.exit(f"k must be an integer, got {sys.argv[2]!r}")
        if k < 0:
            sys.exit(f"k must be non-negative, got {k}")

    # JSON decodes a whole-number component such as 0 as int; coerce every
    # component to float so the list is homogeneous when it is sent as a
    # query parameter.
    vec = [float(x) for x in embed(query)]

    with psycopg.connect(DB) as conn, conn.cursor() as cur:
        cur.execute("SELECT id, dim FROM komp.space WHERE name=%s", (SPACE,))
        row = cur.fetchone()
        if not row:
            sys.exit(f"space {SPACE} missing")
        sid, dim = row
        # The table name is interpolated into the SQL text below, so only
        # whitelisted dimensions are allowed through.
        if dim not in (768, 1024):
            sys.exit(f"unsupported dim {dim}")
        # Catch a model/space mismatch here rather than as a database error.
        if len(vec) != dim:
            sys.exit(
                f"model {MODEL} returned {len(vec)} dims, "
                f"space {SPACE} expects {dim}"
            )
        table = f"komp.embedding_{dim}"
        # cosine distance with vector_cosine_ops: a smaller score is a
        # closer match, so ascending order puts the best hits first.
        sql = f"""
            SELECT (e.embedding <=> %(v)s::vector) AS score, s.uri, k.lineno, k.text
            FROM {table} e
            JOIN komp.chunk k ON k.id = e.chunk_id
            JOIN komp.source s ON s.id = k.source_id
            WHERE e.space_id = %(sid)s
            ORDER BY e.embedding <=> %(v)s::vector
            LIMIT %(k)s
        """
        cur.execute(sql, {"v": vec, "sid": sid, "k": k})
        out = [
            {"score": float(r[0]), "uri": r[1], "lineno": r[2], "text": r[3]}
            for r in cur.fetchall()
        ]

    print(json.dumps(out, ensure_ascii=False, indent=2))


if __name__ == "__main__":
    main()
|
|
|