435 lines
14 KiB
C++
435 lines
14 KiB
C++
#pragma once
|
|
|
|
#include <algorithm>
|
|
#include <cctype>
|
|
#include <cstdlib>
|
|
#include <iomanip>
|
|
#include <optional>
|
|
#include <sstream>
|
|
#include <string>
|
|
#include <unordered_set>
|
|
#include <vector>
|
|
|
|
#include "dal/PgDal.hpp"
|
|
|
|
namespace Handlers {
|
|
namespace detail {
|
|
|
|
// Returns the process-wide kom::PgDal instance, connecting it on first use.
//
// Both the instance and the one-shot connection flag are function-local
// statics, so the connection attempt runs once (C++11 guarantees thread-safe
// initialization of such statics). The DSN comes from the PG_DSN environment
// variable when it is set and non-empty; otherwise "stub://memory" is used.
// The value returned by connect() is not inspected: callers get the instance
// whether or not the connection attempt succeeded.
inline kom::PgDal& database() {
    static kom::PgDal instance;
    [[maybe_unused]] static bool connected = [] {
        const char* dsnFromEnv = std::getenv("PG_DSN");
        const bool haveDsn = dsnFromEnv != nullptr && *dsnFromEnv != '\0';
        if (!haveDsn) {
            return instance.connect("stub://memory");
        }
        return instance.connect(std::string(dsnFromEnv));
    }();
    return instance;
}
|
|
|
|
// Escapes `in` so it can be embedded between double quotes in a JSON document.
//
// Quote and backslash are backslash-escaped. \b, \f, \n, \r and \t use their
// short escapes. Every other control character below U+0020 is written as
// \u00XX, because RFC 8259 forbids raw control characters inside JSON strings.
// All other bytes (including UTF-8 multi-byte sequences and 0x7F) are copied
// unchanged.
inline std::string json_escape(const std::string& in) {
    static const char kHexDigits[] = "0123456789abcdef";
    std::string out;
    out.reserve(in.size() + 2);
    for (char c : in) {
        switch (c) {
            case '\"': out += "\\\""; break;
            case '\\': out += "\\\\"; break;
            case '\b': out += "\\b"; break;
            case '\f': out += "\\f"; break;
            case '\n': out += "\\n"; break;
            case '\r': out += "\\r"; break;
            case '\t': out += "\\t"; break;
            default: {
                // Compare as unsigned so bytes >= 0x80 are never mistaken for controls.
                const auto uc = static_cast<unsigned char>(c);
                if (uc < 0x20) {
                    out += "\\u00";
                    out += kHexDigits[uc >> 4];
                    out += kHexDigits[uc & 0x0F];
                } else {
                    out += c;
                }
                break;
            }
        }
    }
    return out;
}
|
|
|
|
// Builds the error envelope {"error":{"code":"<code>","message":"<message>"}}.
// Both fields are passed through json_escape before being embedded.
inline std::string error_response(const std::string& code, const std::string& message) {
    std::string body = "{\"error\":{\"code\":\"";
    body += json_escape(code);
    body += "\",\"message\":\"";
    body += json_escape(message);
    body += "\"}}";
    return body;
}
|
|
|
|
// Returns the raw text of the value stored under `key` when that value is
// delimited by `open`/`close` (e.g. '{'/'}' for objects, '['/']' for arrays).
// The returned text includes both delimiters.
//
// Only an occurrence of "key" followed (after optional whitespace) by ':' is
// treated as a key, so the same text appearing as a string *value* is
// skipped. The value after the colon must itself begin with `open`; if it is
// anything else (null, a number, a string, ...) std::nullopt is returned
// rather than some later, unrelated segment. Delimiters inside JSON strings
// (honouring backslash escapes) do not affect nesting. Also returns
// std::nullopt when the key is missing or the segment is unterminated.
// NOTE: this is a substring scan, not a full parser; a matching key inside an
// earlier nested object wins.
inline std::optional<std::string> find_delimited_segment(const std::string& json,
                                                         const std::string& key,
                                                         char open,
                                                         char close) {
    const std::string pattern = "\"" + key + "\"";
    const auto skip_ws = [&json](std::size_t p) {
        while (p < json.size() && std::isspace(static_cast<unsigned char>(json[p]))) ++p;
        return p;
    };

    // Find the first occurrence that sits in key position (followed by ':').
    std::size_t pos = std::string::npos;
    for (auto hit = json.find(pattern); hit != std::string::npos; hit = json.find(pattern, hit + 1)) {
        const std::size_t after = skip_ws(hit + pattern.size());
        if (after < json.size() && json[after] == ':') {
            pos = skip_ws(after + 1);
            break;
        }
    }
    // The key's own value must open with the requested delimiter.
    if (pos >= json.size() || json[pos] != open) return std::nullopt;

    int depth = 0;
    bool inString = false;
    bool escape = false;
    for (std::size_t i = pos; i < json.size(); ++i) {
        const char c = json[i];
        if (escape) {
            escape = false;
            continue;
        }
        if (inString) {
            if (c == '\\') {
                escape = true;
            } else if (c == '\"') {
                inString = false;
            }
            continue;
        }
        if (c == '\"') {
            inString = true;
            continue;
        }
        if (c == open) {
            ++depth;
            continue;
        }
        if (c == close && --depth == 0) {
            return json.substr(pos, i - pos + 1);
        }
    }
    return std::nullopt;
}
|
|
|
|
inline std::optional<std::string> find_object_segment(const std::string& json, const std::string& key) {
|
|
return find_delimited_segment(json, key, '{', '}');
|
|
}
|
|
|
|
inline std::optional<std::string> find_array_segment(const std::string& json, const std::string& key) {
|
|
return find_delimited_segment(json, key, '[', ']');
|
|
}
|
|
|
|
// Returns the decoded string value of the first "key": "..." pair in `json`.
//
// Only an occurrence of the quoted key followed (after optional whitespace)
// by ':' counts as a key; the same text appearing as a value is skipped.
// Returns "" when the key is absent or its value is not a string.
//
// Escape decoding covers \" \\ \/ \b \f \n \r \t and \uXXXX. A \uXXXX
// escape is emitted as UTF-8: surrogate pairs are combined, and a lone
// surrogate becomes U+FFFD. A malformed \u escape keeps the letter 'u'
// literally, and unknown escapes keep the escaped character.
//
// An unterminated string yields what was read before the end of input.
// NOTE: this is a substring scan, not a parser; a matching key inside an
// earlier nested object wins.
inline std::string extract_string_field(const std::string& json, const std::string& key) {
    const std::string pattern = "\"" + key + "\"";
    const auto skip_ws = [&json](std::size_t p) {
        while (p < json.size() && std::isspace(static_cast<unsigned char>(json[p]))) ++p;
        return p;
    };

    // Find the first occurrence that sits in key position (followed by ':').
    std::size_t pos = std::string::npos;
    for (auto hit = json.find(pattern); hit != std::string::npos; hit = json.find(pattern, hit + 1)) {
        const std::size_t after = skip_ws(hit + pattern.size());
        if (after < json.size() && json[after] == ':') {
            pos = skip_ws(after + 1);
            break;
        }
    }
    if (pos >= json.size() || json[pos] != '\"') return {};
    ++pos;

    // Parses exactly four hex digits starting at `at`; false if out of range or not hex.
    const auto read_hex4 = [&json](std::size_t at, unsigned& value) {
        if (at + 4 > json.size()) return false;
        value = 0;
        for (std::size_t k = at; k < at + 4; ++k) {
            const char h = json[k];
            value <<= 4;
            if (h >= '0' && h <= '9') value |= static_cast<unsigned>(h - '0');
            else if (h >= 'a' && h <= 'f') value |= static_cast<unsigned>(h - 'a' + 10);
            else if (h >= 'A' && h <= 'F') value |= static_cast<unsigned>(h - 'A' + 10);
            else return false;
        }
        return true;
    };
    // Appends code point `cp` to `dst` encoded as UTF-8.
    const auto append_utf8 = [](std::string& dst, unsigned cp) {
        if (cp < 0x80) {
            dst += static_cast<char>(cp);
        } else if (cp < 0x800) {
            dst += static_cast<char>(0xC0 | (cp >> 6));
            dst += static_cast<char>(0x80 | (cp & 0x3F));
        } else if (cp < 0x10000) {
            dst += static_cast<char>(0xE0 | (cp >> 12));
            dst += static_cast<char>(0x80 | ((cp >> 6) & 0x3F));
            dst += static_cast<char>(0x80 | (cp & 0x3F));
        } else {
            dst += static_cast<char>(0xF0 | (cp >> 18));
            dst += static_cast<char>(0x80 | ((cp >> 12) & 0x3F));
            dst += static_cast<char>(0x80 | ((cp >> 6) & 0x3F));
            dst += static_cast<char>(0x80 | (cp & 0x3F));
        }
    };

    std::string out;
    for (; pos < json.size(); ++pos) {
        const char c = json[pos];
        if (c == '\"') break;
        if (c != '\\') {
            out += c;
            continue;
        }
        if (++pos >= json.size()) break;  // dangling backslash at end of input
        const char esc = json[pos];
        switch (esc) {
            case 'b': out += '\b'; break;
            case 'f': out += '\f'; break;
            case 'n': out += '\n'; break;
            case 'r': out += '\r'; break;
            case 't': out += '\t'; break;
            case 'u': {
                unsigned unit = 0;
                if (!read_hex4(pos + 1, unit)) {
                    out += 'u';  // malformed escape: keep the letter literally
                    break;
                }
                pos += 4;  // now on the last hex digit
                unsigned cp = unit;
                if (unit >= 0xD800 && unit <= 0xDBFF) {
                    // High surrogate: combine with a following \uDC00-\uDFFF if present.
                    unsigned low = 0;
                    if (pos + 2 < json.size() && json[pos + 1] == '\\' && json[pos + 2] == 'u' &&
                        read_hex4(pos + 3, low) && low >= 0xDC00 && low <= 0xDFFF) {
                        cp = 0x10000 + ((unit - 0xD800) << 10) + (low - 0xDC00);
                        pos += 6;
                    } else {
                        cp = 0xFFFD;
                    }
                } else if (unit >= 0xDC00 && unit <= 0xDFFF) {
                    cp = 0xFFFD;  // lone low surrogate
                }
                append_utf8(out, cp);
                break;
            }
            default:
                // \" \\ \/ and unknown escapes: emit the escaped character itself.
                out += esc;
                break;
        }
    }
    return out;
}
|
|
|
|
// Returns the integer value of the first "key": <int> pair in `json`.
//
// Only an occurrence of the quoted key followed (after optional whitespace)
// by ':' counts as a key; the same text appearing as a value is skipped.
// The value is an optional leading '-' followed by decimal digits; anything
// after the digits (e.g. a fractional part) is ignored, so 4.9 yields 4.
// Returns std::nullopt when the key is absent, the value is not an integer,
// or it does not fit in an int.
inline std::optional<int> extract_int_field(const std::string& json, const std::string& key) {
    const std::string pattern = "\"" + key + "\"";
    const auto skip_ws = [&json](std::size_t p) {
        while (p < json.size() && std::isspace(static_cast<unsigned char>(json[p]))) ++p;
        return p;
    };

    // Find the first occurrence that sits in key position (followed by ':').
    std::size_t pos = std::string::npos;
    for (auto hit = json.find(pattern); hit != std::string::npos; hit = json.find(pattern, hit + 1)) {
        const std::size_t after = skip_ws(hit + pattern.size());
        if (after < json.size() && json[after] == ':') {
            pos = skip_ws(after + 1);
            break;
        }
    }
    if (pos >= json.size()) return std::nullopt;

    const std::size_t start = pos;
    if (json[pos] == '-') ++pos;  // sign is only valid in leading position
    const std::size_t digitsBegin = pos;
    while (pos < json.size() && std::isdigit(static_cast<unsigned char>(json[pos]))) ++pos;
    if (pos == digitsBegin) return std::nullopt;

    try {
        return std::stoi(json.substr(start, pos - start));
    } catch (...) {
        // The token is a validated integer, so only overflow can land here; treat as absent.
        return std::nullopt;
    }
}
|
|
|
|
inline std::vector<std::string> parse_object_array(const std::string& json, const std::string& key) {
|
|
std::vector<std::string> objects;
|
|
auto segment = find_array_segment(json, key);
|
|
if (!segment) return objects;
|
|
const std::string& arr = *segment;
|
|
|
|
int depth = 0;
|
|
bool inString = false;
|
|
bool escape = false;
|
|
std::size_t start = std::string::npos;
|
|
for (std::size_t i = 0; i < arr.size(); ++i) {
|
|
char c = arr[i];
|
|
if (escape) { escape = false; continue; }
|
|
if (c == '\\') {
|
|
if (inString) escape = true;
|
|
continue;
|
|
}
|
|
if (c == '\"') {
|
|
inString = !inString;
|
|
continue;
|
|
}
|
|
if (inString) continue;
|
|
|
|
if (c == '{') {
|
|
if (depth == 0) start = i;
|
|
++depth;
|
|
} else if (c == '}') {
|
|
--depth;
|
|
if (depth == 0 && start != std::string::npos) {
|
|
objects.push_back(arr.substr(start, i - start + 1));
|
|
start = std::string::npos;
|
|
}
|
|
}
|
|
}
|
|
return objects;
|
|
}
|
|
|
|
inline std::vector<std::string> parse_string_array(const std::string& json, const std::string& key) {
|
|
std::vector<std::string> values;
|
|
auto segment = find_array_segment(json, key);
|
|
if (!segment) return values;
|
|
const std::string& arr = *segment;
|
|
|
|
bool inString = false;
|
|
bool escape = false;
|
|
std::ostringstream current;
|
|
for (std::size_t i = 0; i < arr.size(); ++i) {
|
|
char c = arr[i];
|
|
if (escape) {
|
|
switch (c) {
|
|
case '\"': current << '\"'; break;
|
|
case '\\': current << '\\'; break;
|
|
case 'n': current << '\n'; break;
|
|
case 'r': current << '\r'; break;
|
|
case 't': current << '\t'; break;
|
|
default: current << c; break;
|
|
}
|
|
escape = false;
|
|
continue;
|
|
}
|
|
if (c == '\\' && inString) { escape = true; continue; }
|
|
if (c == '\"') {
|
|
if (inString) {
|
|
values.push_back(current.str());
|
|
current.str("");
|
|
current.clear();
|
|
}
|
|
inString = !inString;
|
|
continue;
|
|
}
|
|
if (inString) current << c;
|
|
}
|
|
return values;
|
|
}
|
|
|
|
inline std::vector<float> parse_float_array(const std::string& json, const std::string& key) {
|
|
std::vector<float> values;
|
|
auto segment = find_array_segment(json, key);
|
|
if (!segment) return values;
|
|
const std::string& arr = *segment;
|
|
std::size_t pos = 0;
|
|
while (pos < arr.size()) {
|
|
while (pos < arr.size() && !std::isdigit(static_cast<unsigned char>(arr[pos])) && arr[pos] != '-' && arr[pos] != '+') ++pos;
|
|
if (pos >= arr.size()) break;
|
|
std::size_t end = pos;
|
|
while (end < arr.size() && (std::isdigit(static_cast<unsigned char>(arr[end])) || arr[end] == '.' || arr[end] == 'e' || arr[end] == 'E' || arr[end] == '+' || arr[end] == '-')) ++end;
|
|
try {
|
|
values.push_back(std::stof(arr.substr(pos, end - pos)));
|
|
} catch (...) {
|
|
// Skip unparsable token.
|
|
}
|
|
pos = end;
|
|
}
|
|
return values;
|
|
}
|
|
|
|
// Formats `score` in fixed-point notation with exactly three decimals
// (e.g. 1.0f -> "1.000").
inline std::string format_score(float score) {
    std::ostringstream out;
    out << std::fixed << std::setprecision(3) << score;
    return out.str();
}
|
|
|
|
// One element of the request's "items" array after field extraction.
struct ParsedItem {
    std::string id{};                 // extracted "id" string field
    std::string text{};               // extracted "text" string field
    std::vector<std::string> tags{};  // elements of the "tags" array
    std::vector<float> embedding{};   // numbers from the "embedding" array
    std::string rawJson{};            // the item's JSON object text, verbatim
};
|
|
|
|
inline std::vector<ParsedItem> parse_items(const std::string& json) {
|
|
std::vector<ParsedItem> items;
|
|
for (const auto& obj : parse_object_array(json, "items")) {
|
|
ParsedItem item;
|
|
item.rawJson = obj;
|
|
item.id = extract_string_field(obj, "id");
|
|
item.text = extract_string_field(obj, "text");
|
|
item.tags = parse_string_array(obj, "tags");
|
|
item.embedding = parse_float_array(obj, "embedding");
|
|
items.push_back(std::move(item));
|
|
}
|
|
return items;
|
|
}
|
|
|
|
// One search hit as serialized in the "matches" response array.
// `score` is default-initialized so a default-constructed match never carries
// an indeterminate value; `text` is omitted from output when empty (nullopt).
struct SearchMatch {
    std::string id;
    float score = 0.0f;
    std::optional<std::string> text;
};
|
|
|
|
inline std::string serialize_matches(const std::vector<SearchMatch>& matches) {
|
|
std::ostringstream os;
|
|
os << "{\"matches\":[";
|
|
for (std::size_t i = 0; i < matches.size(); ++i) {
|
|
const auto& match = matches[i];
|
|
if (i) os << ",";
|
|
os << "{\"id\":\"" << json_escape(match.id) << "\""
|
|
<< ",\"score\":" << format_score(match.score);
|
|
if (match.text) {
|
|
os << ",\"text\":\"" << json_escape(*match.text) << "\"";
|
|
}
|
|
os << "}";
|
|
}
|
|
os << "]}";
|
|
return os.str();
|
|
}
|
|
|
|
} // namespace detail
|
|
|
|
inline std::string upsert_memory(const std::string& reqJson) {
|
|
const std::string nsName = detail::extract_string_field(reqJson, "namespace");
|
|
if (nsName.empty()) {
|
|
return detail::error_response("bad_request", "namespace is required");
|
|
}
|
|
|
|
auto nsRow = detail::database().ensureNamespace(nsName);
|
|
if (!nsRow) {
|
|
return detail::error_response("internal_error", "failed to ensure namespace");
|
|
}
|
|
|
|
auto items = detail::parse_items(reqJson);
|
|
if (items.empty()) {
|
|
return detail::error_response("bad_request", "items array must contain at least one entry");
|
|
}
|
|
|
|
kom::PgDal& dal = detail::database();
|
|
const bool hasTx = dal.begin();
|
|
std::vector<std::string> ids;
|
|
ids.reserve(items.size());
|
|
|
|
try {
|
|
for (auto& parsed : items) {
|
|
kom::ItemRow row;
|
|
row.id = parsed.id;
|
|
row.namespace_id = nsRow->id;
|
|
row.content_json = parsed.rawJson;
|
|
row.text = parsed.text.empty() ? std::optional<std::string>() : std::optional<std::string>(parsed.text);
|
|
row.tags = parsed.tags;
|
|
row.revision = 1;
|
|
|
|
const std::string itemId = dal.upsertItem(row);
|
|
ids.push_back(itemId);
|
|
|
|
if (!parsed.embedding.empty()) {
|
|
kom::ChunkRow chunk;
|
|
chunk.item_id = itemId;
|
|
chunk.ord = 0;
|
|
chunk.text = parsed.text;
|
|
auto chunkIds = dal.upsertChunks(std::vector<kom::ChunkRow>{chunk});
|
|
|
|
kom::EmbeddingRow emb;
|
|
emb.chunk_id = chunkIds.front();
|
|
emb.model = "stub-model";
|
|
emb.dim = static_cast<int>(parsed.embedding.size());
|
|
emb.vector = parsed.embedding;
|
|
dal.upsertEmbeddings(std::vector<kom::EmbeddingRow>{emb});
|
|
}
|
|
}
|
|
if (hasTx) dal.commit();
|
|
} catch (const std::exception& ex) {
|
|
if (hasTx) dal.rollback();
|
|
return detail::error_response("internal_error", ex.what());
|
|
}
|
|
|
|
std::ostringstream os;
|
|
os << "{\"upserted\":" << ids.size();
|
|
if (!ids.empty()) {
|
|
os << ",\"ids\":[";
|
|
for (std::size_t i = 0; i < ids.size(); ++i) {
|
|
if (i) os << ",";
|
|
os << "\"" << detail::json_escape(ids[i]) << "\"";
|
|
}
|
|
os << "]";
|
|
}
|
|
os << ",\"status\":\"ok\"}";
|
|
return os.str();
|
|
}
|
|
|
|
inline std::string search_memory(const std::string& reqJson) {
|
|
const std::string nsName = detail::extract_string_field(reqJson, "namespace");
|
|
if (nsName.empty()) {
|
|
return detail::error_response("bad_request", "namespace is required");
|
|
}
|
|
|
|
auto nsRow = detail::database().findNamespace(nsName);
|
|
if (!nsRow) {
|
|
return "{\"matches\":[]}";
|
|
}
|
|
|
|
std::string queryText;
|
|
std::vector<float> queryEmbedding;
|
|
int limit = 5;
|
|
|
|
if (auto queryObj = detail::find_object_segment(reqJson, "query")) {
|
|
queryText = detail::extract_string_field(*queryObj, "text");
|
|
queryEmbedding = detail::parse_float_array(*queryObj, "embedding");
|
|
if (auto k = detail::extract_int_field(*queryObj, "k")) {
|
|
if (*k > 0) limit = *k;
|
|
}
|
|
}
|
|
|
|
kom::PgDal& dal = detail::database();
|
|
std::unordered_set<std::string> seen;
|
|
std::vector<detail::SearchMatch> matches;
|
|
|
|
auto textRows = dal.searchText(nsRow->id, queryText, limit);
|
|
for (std::size_t idx = 0; idx < textRows.size(); ++idx) {
|
|
const auto& row = textRows[idx];
|
|
detail::SearchMatch match;
|
|
match.id = row.id;
|
|
match.text = row.text;
|
|
match.score = 1.0f - static_cast<float>(idx) * 0.05f;
|
|
matches.push_back(match);
|
|
seen.insert(match.id);
|
|
if (static_cast<int>(matches.size()) >= limit) break;
|
|
}
|
|
|
|
if (static_cast<int>(matches.size()) < limit && !queryEmbedding.empty()) {
|
|
auto vectorMatches = dal.searchVector(nsRow->id, queryEmbedding, limit);
|
|
for (const auto& pair : vectorMatches) {
|
|
if (seen.count(pair.first)) continue;
|
|
auto item = dal.getItemById(pair.first);
|
|
if (!item) continue;
|
|
detail::SearchMatch match;
|
|
match.id = pair.first;
|
|
match.score = pair.second;
|
|
match.text = item->text;
|
|
matches.push_back(match);
|
|
if (static_cast<int>(matches.size()) >= limit) break;
|
|
}
|
|
}
|
|
|
|
return detail::serialize_matches(matches);
|
|
}
|
|
|
|
} // namespace Handlers
|