#pragma once #include #include #include #include #include #include #include #include #include #include "dal/PgDal.hpp" namespace Handlers { namespace detail { inline kom::PgDal& database() { static kom::PgDal instance; static bool connected = [] { const char* env = std::getenv("PG_DSN"); const std::string dsn = (env && *env) ? std::string(env) : std::string(); if (!dsn.empty()) { return instance.connect(dsn); } return instance.connect("stub://memory"); }(); (void)connected; return instance; } inline std::string json_escape(const std::string& in) { std::ostringstream os; for (char c : in) { switch (c) { case '\"': os << "\\\""; break; case '\\': os << "\\\\"; break; case '\n': os << "\\n"; break; case '\r': os << "\\r"; break; case '\t': os << "\\t"; break; default: os << c; break; } } return os.str(); } inline std::string error_response(const std::string& code, const std::string& message) { std::ostringstream os; os << "{\"error\":{\"code\":\"" << json_escape(code) << "\",\"message\":\"" << json_escape(message) << "\"}}"; return os.str(); } inline std::optional find_delimited_segment(const std::string& json, const std::string& key, char open, char close) { const std::string pattern = "\"" + key + "\""; auto pos = json.find(pattern); if (pos == std::string::npos) return std::nullopt; pos = json.find(open, pos); if (pos == std::string::npos) return std::nullopt; int depth = 0; bool inString = false; bool escape = false; std::size_t start = std::string::npos; for (std::size_t i = pos; i < json.size(); ++i) { char c = json[i]; if (escape) { escape = false; continue; } if (c == '\\') { if (inString) escape = true; continue; } if (c == '\"') { inString = !inString; continue; } if (inString) continue; if (c == open) { if (depth == 0) start = i; ++depth; continue; } if (c == close) { --depth; if (depth == 0 && start != std::string::npos) { return json.substr(start, i - start + 1); } } } return std::nullopt; } inline std::optional find_object_segment(const std::string& json, const std::string& key) { return find_delimited_segment(json, key, '{', '}'); } inline std::optional find_array_segment(const std::string& json, const std::string& key) { return find_delimited_segment(json, key, '[', ']'); } inline std::string extract_string_field(const std::string& json, const std::string& key) { const std::string pattern = "\"" + key + "\""; auto pos = json.find(pattern); if (pos == std::string::npos) return {}; pos = json.find(':', pos); if (pos == std::string::npos) return {}; ++pos; while (pos < json.size() && std::isspace(static_cast(json[pos]))) ++pos; if (pos >= json.size() || json[pos] != '\"') return {}; ++pos; std::ostringstream os; bool escape = false; for (; pos < json.size(); ++pos) { char c = json[pos]; if (escape) { switch (c) { case '\"': os << '\"'; break; case '\\': os << '\\'; break; case '/': os << '/'; break; case 'b': os << '\b'; break; case 'f': os << '\f'; break; case 'n': os << '\n'; break; case 'r': os << '\r'; break; case 't': os << '\t'; break; default: os << c; break; } escape = false; continue; } if (c == '\\') { escape = true; continue; } if (c == '\"') break; os << c; } return os.str(); } inline std::optional extract_int_field(const std::string& json, const std::string& key) { const std::string pattern = "\"" + key + "\""; auto pos = json.find(pattern); if (pos == std::string::npos) return std::nullopt; pos = json.find(':', pos); if (pos == std::string::npos) return std::nullopt; ++pos; while (pos < json.size() && std::isspace(static_cast(json[pos]))) ++pos; std::size_t start = pos; while (pos < json.size() && (std::isdigit(static_cast(json[pos])) || json[pos] == '-')) ++pos; if (start == pos) return std::nullopt; try { return std::stoi(json.substr(start, pos - start)); } catch (...) { return std::nullopt; } } inline std::vector parse_object_array(const std::string& json, const std::string& key) { std::vector objects; auto segment = find_array_segment(json, key); if (!segment) return objects; const std::string& arr = *segment; int depth = 0; bool inString = false; bool escape = false; std::size_t start = std::string::npos; for (std::size_t i = 0; i < arr.size(); ++i) { char c = arr[i]; if (escape) { escape = false; continue; } if (c == '\\') { if (inString) escape = true; continue; } if (c == '\"') { inString = !inString; continue; } if (inString) continue; if (c == '{') { if (depth == 0) start = i; ++depth; } else if (c == '}') { --depth; if (depth == 0 && start != std::string::npos) { objects.push_back(arr.substr(start, i - start + 1)); start = std::string::npos; } } } return objects; } inline std::vector parse_string_array(const std::string& json, const std::string& key) { std::vector values; auto segment = find_array_segment(json, key); if (!segment) return values; const std::string& arr = *segment; bool inString = false; bool escape = false; std::ostringstream current; for (std::size_t i = 0; i < arr.size(); ++i) { char c = arr[i]; if (escape) { switch (c) { case '\"': current << '\"'; break; case '\\': current << '\\'; break; case 'n': current << '\n'; break; case 'r': current << '\r'; break; case 't': current << '\t'; break; default: current << c; break; } escape = false; continue; } if (c == '\\' && inString) { escape = true; continue; } if (c == '\"') { if (inString) { values.push_back(current.str()); current.str(""); current.clear(); } inString = !inString; continue; } if (inString) current << c; } return values; } inline std::vector parse_float_array(const std::string& json, const std::string& key) { std::vector values; auto segment = find_array_segment(json, key); if (!segment) return values; const std::string& arr = *segment; std::size_t pos = 0; while (pos < arr.size()) { while (pos < arr.size() && !std::isdigit(static_cast(arr[pos])) && arr[pos] != '-' && arr[pos] != '+') ++pos; if (pos >= arr.size()) break; std::size_t end = pos; while (end < arr.size() && (std::isdigit(static_cast(arr[end])) || arr[end] == '.' || arr[end] == 'e' || arr[end] == 'E' || arr[end] == '+' || arr[end] == '-')) ++end; try { values.push_back(std::stof(arr.substr(pos, end - pos))); } catch (...) { // Skip unparsable token. } pos = end; } return values; } inline std::string format_score(float score) { std::ostringstream os; os.setf(std::ios::fixed); os << std::setprecision(3) << score; return os.str(); } struct ParsedItem { std::string id; std::string text; std::vector tags; std::vector embedding; std::string rawJson; }; inline std::vector parse_items(const std::string& json) { std::vector items; for (const auto& obj : parse_object_array(json, "items")) { ParsedItem item; item.rawJson = obj; item.id = extract_string_field(obj, "id"); item.text = extract_string_field(obj, "text"); item.tags = parse_string_array(obj, "tags"); item.embedding = parse_float_array(obj, "embedding"); items.push_back(std::move(item)); } return items; } struct SearchMatch { std::string id; float score; std::optional text; }; inline std::string serialize_matches(const std::vector& matches) { std::ostringstream os; os << "{\"matches\":["; for (std::size_t i = 0; i < matches.size(); ++i) { const auto& match = matches[i]; if (i) os << ","; os << "{\"id\":\"" << json_escape(match.id) << "\"" << ",\"score\":" << format_score(match.score); if (match.text) { os << ",\"text\":\"" << json_escape(*match.text) << "\""; } os << "}"; } os << "]}"; return os.str(); } } // namespace detail inline std::string upsert_memory(const std::string& reqJson) { const std::string nsName = detail::extract_string_field(reqJson, "namespace"); if (nsName.empty()) { return detail::error_response("bad_request", "namespace is required"); } auto nsRow = detail::database().ensureNamespace(nsName); if (!nsRow) { return detail::error_response("internal_error", "failed to ensure namespace"); } auto items = detail::parse_items(reqJson); if (items.empty()) { return detail::error_response("bad_request", "items array must contain at least one entry"); } kom::PgDal& dal = detail::database(); const bool hasTx = dal.begin(); std::vector ids; ids.reserve(items.size()); try { for (auto& parsed : items) { kom::ItemRow row; row.id = parsed.id; row.namespace_id = nsRow->id; row.content_json = parsed.rawJson; row.text = parsed.text.empty() ? std::optional() : std::optional(parsed.text); row.tags = parsed.tags; row.revision = 1; const std::string itemId = dal.upsertItem(row); ids.push_back(itemId); if (!parsed.embedding.empty()) { kom::ChunkRow chunk; chunk.item_id = itemId; chunk.ord = 0; chunk.text = parsed.text; auto chunkIds = dal.upsertChunks(std::vector{chunk}); kom::EmbeddingRow emb; emb.chunk_id = chunkIds.front(); emb.model = "stub-model"; emb.dim = static_cast(parsed.embedding.size()); emb.vector = parsed.embedding; dal.upsertEmbeddings(std::vector{emb}); } } if (hasTx) dal.commit(); } catch (const std::exception& ex) { if (hasTx) dal.rollback(); return detail::error_response("internal_error", ex.what()); } std::ostringstream os; os << "{\"upserted\":" << ids.size(); if (!ids.empty()) { os << ",\"ids\":["; for (std::size_t i = 0; i < ids.size(); ++i) { if (i) os << ","; os << "\"" << detail::json_escape(ids[i]) << "\""; } os << "]"; } os << ",\"status\":\"ok\"}"; return os.str(); } inline std::string search_memory(const std::string& reqJson) { const std::string nsName = detail::extract_string_field(reqJson, "namespace"); if (nsName.empty()) { return detail::error_response("bad_request", "namespace is required"); } auto nsRow = detail::database().findNamespace(nsName); if (!nsRow) { return "{\"matches\":[]}"; } std::string queryText; std::vector queryEmbedding; int limit = 5; if (auto queryObj = detail::find_object_segment(reqJson, "query")) { queryText = detail::extract_string_field(*queryObj, "text"); queryEmbedding = detail::parse_float_array(*queryObj, "embedding"); if (auto k = detail::extract_int_field(*queryObj, "k")) { if (*k > 0) limit = *k; } } kom::PgDal& dal = detail::database(); std::unordered_set seen; std::vector matches; auto textRows = dal.searchText(nsRow->id, queryText, limit); for (std::size_t idx = 0; idx < textRows.size(); ++idx) { const auto& row = textRows[idx]; detail::SearchMatch match; match.id = row.id; match.text = row.text; match.score = 1.0f - static_cast(idx) * 0.05f; matches.push_back(match); seen.insert(match.id); if (static_cast(matches.size()) >= limit) break; } if (static_cast(matches.size()) < limit && !queryEmbedding.empty()) { auto vectorMatches = dal.searchVector(nsRow->id, queryEmbedding, limit); for (const auto& pair : vectorMatches) { if (seen.count(pair.first)) continue; auto item = dal.getItemById(pair.first); if (!item) continue; detail::SearchMatch match; match.id = pair.first; match.score = pair.second; match.text = item->text; matches.push_back(match); if (static_cast(matches.size()) >= limit) break; } } return detail::serialize_matches(matches); } } // namespace Handlers