metal-kompanion/src/mcp/HandlersMemory.hpp

435 lines
14 KiB
C++

#pragma once
#include <algorithm>
#include <cctype>
#include <cmath>
#include <cstdlib>
#include <iomanip>
#include <optional>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_set>
#include <vector>
#include "dal/PgDal.hpp"
namespace Handlers {
namespace detail {
/// Returns the shared DAL instance, connecting it lazily on first call.
///
/// The DSN comes from the PG_DSN environment variable; when that variable is
/// unset or empty, the stub backend "stub://memory" is used instead. The value
/// returned by connect() is evaluated exactly once and is otherwise ignored.
inline kom::PgDal& database() {
    static kom::PgDal instance;
    static bool connectedOnce = [] {
        const char* rawDsn = std::getenv("PG_DSN");
        if (rawDsn == nullptr || *rawDsn == '\0') {
            return instance.connect("stub://memory");
        }
        const std::string dsn(rawDsn);
        return instance.connect(dsn);
    }();
    static_cast<void>(connectedOnce);
    return instance;
}
/// Escapes `in` for use inside a JSON string literal (without the surrounding
/// quotes).
///
/// Quotes, backslashes and every control character below U+0020 are escaped,
/// as RFC 8259 requires. The common control characters use their short forms
/// (\b \f \n \r \t); all others use \u00XX. Bytes >= 0x80 (e.g. UTF-8
/// sequences) are copied through unchanged.
inline std::string json_escape(const std::string& in) {
    static constexpr char kHex[] = "0123456789abcdef";
    std::string out;
    out.reserve(in.size() + 8);
    for (const char c : in) {
        switch (c) {
        case '\"': out += "\\\""; break;
        case '\\': out += "\\\\"; break;
        case '\b': out += "\\b"; break;
        case '\f': out += "\\f"; break;
        case '\n': out += "\\n"; break;
        case '\r': out += "\\r"; break;
        case '\t': out += "\\t"; break;
        default: {
            const auto byte = static_cast<unsigned char>(c);
            if (byte < 0x20) {
                // Raw control characters are not allowed in JSON strings.
                out += "\\u00";
                out += kHex[byte >> 4];
                out += kHex[byte & 0x0F];
            } else {
                out += c;
            }
            break;
        }
        }
    }
    return out;
}
/// Builds the error payload {"error":{"code":"<code>","message":"<message>"}}.
/// Both values are JSON-escaped before being embedded.
inline std::string error_response(const std::string& code, const std::string& message) {
    std::string body = "{\"error\":{\"code\":\"";
    body += json_escape(code);
    body += "\",\"message\":\"";
    body += json_escape(message);
    body += "\"}}";
    return body;
}
/// Finds the value stored under `"key"` in `json`. If that value starts with
/// `open`, returns the balanced `open`...`close` span, delimiters included.
///
/// This is a lightweight scanner, not a full JSON parser. It uses the first
/// occurrence of `"key"` that is followed by ':' (at any nesting depth). Only
/// the requested delimiter pair counts toward depth, and delimiters inside
/// string literals are ignored.
///
/// @return the segment, or std::nullopt when the key is missing, its value is
///         not of the requested kind (e.g. null or a string), or the segment
///         is unterminated.
inline std::optional<std::string> find_delimited_segment(const std::string& json,
                                                         const std::string& key,
                                                         char open,
                                                         char close) {
    const std::string pattern = "\"" + key + "\"";
    const auto skip_ws = [&json](std::size_t p) {
        while (p < json.size() && std::isspace(static_cast<unsigned char>(json[p]))) ++p;
        return p;
    };

    // Use the first occurrence of the quoted key that is really an object key,
    // i.e. followed by ':'. Occurrences as string values are skipped.
    std::size_t valuePos = std::string::npos;
    for (auto hit = json.find(pattern); hit != std::string::npos; hit = json.find(pattern, hit + 1)) {
        const std::size_t colon = skip_ws(hit + pattern.size());
        if (colon < json.size() && json[colon] == ':') {
            valuePos = skip_ws(colon + 1);
            break;
        }
    }

    // The value itself must begin with the requested delimiter. Searching for
    // the next `open` anywhere could return a sibling field's value instead.
    if (valuePos >= json.size() || json[valuePos] != open) return std::nullopt;

    int depth = 0;
    bool inString = false;
    bool escaped = false;
    for (std::size_t i = valuePos; i < json.size(); ++i) {
        const char c = json[i];
        if (inString) {
            if (escaped) {
                escaped = false;
            } else if (c == '\\') {
                escaped = true;
            } else if (c == '\"') {
                inString = false;
            }
            continue;
        }
        if (c == '\"') {
            inString = true;
        } else if (c == open) {
            ++depth;
        } else if (c == close && --depth == 0) {
            return json.substr(valuePos, i - valuePos + 1);
        }
    }
    return std::nullopt;
}
/// Returns the `{...}` object value stored under `key`, if any.
inline std::optional<std::string> find_object_segment(const std::string& json, const std::string& key) {
    constexpr char kObjectOpen = '{';
    constexpr char kObjectClose = '}';
    return find_delimited_segment(json, key, kObjectOpen, kObjectClose);
}
/// Returns the `[...]` array value stored under `key`, if any.
inline std::optional<std::string> find_array_segment(const std::string& json, const std::string& key) {
    constexpr char kArrayOpen = '[';
    constexpr char kArrayClose = ']';
    return find_delimited_segment(json, key, kArrayOpen, kArrayClose);
}
/// Returns the decoded string value of `"key"` in `json`, or "" when the key
/// is missing or its value is not a string.
///
/// This is a lightweight scanner. It uses the first occurrence of `"key"` that
/// is followed by ':' (at any nesting depth). Standard JSON escapes are
/// decoded, including \uXXXX:
///   - surrogate pairs are combined and the result is emitted as UTF-8;
///   - lone surrogates become U+FFFD;
///   - a malformed \u escape emits a literal 'u'.
/// An unterminated string yields whatever was read before the end of input.
inline std::string extract_string_field(const std::string& json, const std::string& key) {
    const std::string pattern = "\"" + key + "\"";
    const auto skip_ws = [&json](std::size_t p) {
        while (p < json.size() && std::isspace(static_cast<unsigned char>(json[p]))) ++p;
        return p;
    };

    // Use only an occurrence that is followed by ':'. Otherwise a string value
    // that equals the key would make us read some later field's value.
    std::size_t pos = std::string::npos;
    for (auto hit = json.find(pattern); hit != std::string::npos; hit = json.find(pattern, hit + 1)) {
        const std::size_t colon = skip_ws(hit + pattern.size());
        if (colon < json.size() && json[colon] == ':') {
            pos = skip_ws(colon + 1);
            break;
        }
    }
    if (pos >= json.size() || json[pos] != '\"') return {};
    ++pos;

    // Reads four hex digits starting at `at`; returns -1 if malformed or truncated.
    const auto read_hex4 = [&json](std::size_t at) -> long {
        if (at + 4 > json.size()) return -1;
        long value = 0;
        for (std::size_t k = at; k < at + 4; ++k) {
            const char h = json[k];
            value <<= 4;
            if (h >= '0' && h <= '9') value |= h - '0';
            else if (h >= 'a' && h <= 'f') value |= h - 'a' + 10;
            else if (h >= 'A' && h <= 'F') value |= h - 'A' + 10;
            else return -1;
        }
        return value;
    };
    // Appends code point `cp` to `out` encoded as UTF-8.
    const auto append_utf8 = [](std::string& out, unsigned long cp) {
        if (cp < 0x80) {
            out += static_cast<char>(cp);
        } else if (cp < 0x800) {
            out += static_cast<char>(0xC0 | (cp >> 6));
            out += static_cast<char>(0x80 | (cp & 0x3F));
        } else if (cp < 0x10000) {
            out += static_cast<char>(0xE0 | (cp >> 12));
            out += static_cast<char>(0x80 | ((cp >> 6) & 0x3F));
            out += static_cast<char>(0x80 | (cp & 0x3F));
        } else {
            out += static_cast<char>(0xF0 | (cp >> 18));
            out += static_cast<char>(0x80 | ((cp >> 12) & 0x3F));
            out += static_cast<char>(0x80 | ((cp >> 6) & 0x3F));
            out += static_cast<char>(0x80 | (cp & 0x3F));
        }
    };

    std::string out;
    for (; pos < json.size(); ++pos) {
        const char c = json[pos];
        if (c == '\"') break;
        if (c != '\\') {
            out += c;
            continue;
        }
        if (++pos >= json.size()) break;
        const char e = json[pos];
        switch (e) {
        case 'b': out += '\b'; break;
        case 'f': out += '\f'; break;
        case 'n': out += '\n'; break;
        case 'r': out += '\r'; break;
        case 't': out += '\t'; break;
        case 'u': {
            long cp = read_hex4(pos + 1);
            if (cp < 0) {
                out += 'u';  // malformed escape: treat like any unknown escape
                break;
            }
            pos += 4;
            if (cp >= 0xD800 && cp <= 0xDBFF) {
                // High surrogate: combine with an immediately following low surrogate.
                long low = -1;
                if (pos + 2 < json.size() && json[pos + 1] == '\\' && json[pos + 2] == 'u') {
                    low = read_hex4(pos + 3);
                }
                if (low >= 0xDC00 && low <= 0xDFFF) {
                    cp = 0x10000 + ((cp - 0xD800) << 10) + (low - 0xDC00);
                    pos += 6;
                } else {
                    cp = 0xFFFD;
                }
            } else if (cp >= 0xDC00 && cp <= 0xDFFF) {
                cp = 0xFFFD;
            }
            append_utf8(out, static_cast<unsigned long>(cp));
            break;
        }
        default: out += e; break;  // \" \\ \/ and unknown escapes map to the char itself
        }
    }
    return out;
}
/// Returns the integer value of `"key"` in `json`, or std::nullopt when the
/// key is missing, its value is not a number, or the value overflows int.
///
/// Uses the first occurrence of `"key"` that is followed by ':'. Only a single
/// leading '-' is accepted. A fractional value is truncated at the '.'
/// (3.7 -> 3).
inline std::optional<int> extract_int_field(const std::string& json, const std::string& key) {
    const std::string pattern = "\"" + key + "\"";
    const auto skip_ws = [&json](std::size_t p) {
        while (p < json.size() && std::isspace(static_cast<unsigned char>(json[p]))) ++p;
        return p;
    };

    // Skip occurrences where the quoted key appears as a string value.
    std::size_t pos = std::string::npos;
    for (auto hit = json.find(pattern); hit != std::string::npos; hit = json.find(pattern, hit + 1)) {
        const std::size_t colon = skip_ws(hit + pattern.size());
        if (colon < json.size() && json[colon] == ':') {
            pos = skip_ws(colon + 1);
            break;
        }
    }
    if (pos >= json.size()) return std::nullopt;

    const std::size_t start = pos;
    if (json[pos] == '-') ++pos;  // a sign is only valid in leading position
    const std::size_t digitsBegin = pos;
    while (pos < json.size() && std::isdigit(static_cast<unsigned char>(json[pos]))) ++pos;
    if (pos == digitsBegin) return std::nullopt;
    try {
        return std::stoi(json.substr(start, pos - start));
    } catch (...) {
        // Only std::out_of_range is possible here; the digits were validated above.
        return std::nullopt;
    }
}
/// Splits the array stored under `key` into the raw text of each top-level
/// `{...}` element. Braces inside string literals are ignored, and elements
/// that are not objects are skipped. Returns an empty vector when the key has
/// no array value.
inline std::vector<std::string> parse_object_array(const std::string& json, const std::string& key) {
    std::vector<std::string> result;
    const auto arrayText = find_array_segment(json, key);
    if (!arrayText.has_value()) {
        return result;
    }
    const std::string& text = arrayText.value();
    int nesting = 0;
    bool quoted = false;
    bool skipNext = false;
    std::size_t objectBegin = std::string::npos;
    std::size_t idx = 0;
    while (idx < text.size()) {
        const char ch = text[idx];
        if (skipNext) {
            skipNext = false;
        } else if (ch == '\\') {
            // An escape only matters inside a string literal.
            skipNext = quoted;
        } else if (ch == '\"') {
            quoted = !quoted;
        } else if (!quoted) {
            if (ch == '{') {
                if (nesting++ == 0) objectBegin = idx;
            } else if (ch == '}') {
                if (--nesting == 0 && objectBegin != std::string::npos) {
                    result.push_back(text.substr(objectBegin, idx - objectBegin + 1));
                    objectBegin = std::string::npos;
                }
            }
        }
        ++idx;
    }
    return result;
}
/// Collects the decoded string literals found in the array stored under `key`.
///
/// Every string literal inside the array text is returned, including strings
/// nested in sub-objects. JSON escapes are decoded the same way as in
/// extract_string_field:
///   - \b \f \n \r \t produce the corresponding control characters;
///   - \uXXXX is emitted as UTF-8, surrogate pairs are combined, and lone
///     surrogates become U+FFFD;
///   - any other escaped character (including \" \\ \/) maps to itself.
/// An unterminated final string is dropped.
inline std::vector<std::string> parse_string_array(const std::string& json, const std::string& key) {
    std::vector<std::string> values;
    const auto segment = find_array_segment(json, key);
    if (!segment) return values;
    const std::string& arr = *segment;

    // Reads four hex digits starting at `at`; returns -1 if malformed or truncated.
    const auto read_hex4 = [&arr](std::size_t at) -> long {
        if (at + 4 > arr.size()) return -1;
        long value = 0;
        for (std::size_t k = at; k < at + 4; ++k) {
            const char h = arr[k];
            value <<= 4;
            if (h >= '0' && h <= '9') value |= h - '0';
            else if (h >= 'a' && h <= 'f') value |= h - 'a' + 10;
            else if (h >= 'A' && h <= 'F') value |= h - 'A' + 10;
            else return -1;
        }
        return value;
    };
    // Appends code point `cp` to `out` encoded as UTF-8.
    const auto append_utf8 = [](std::string& out, unsigned long cp) {
        if (cp < 0x80) {
            out += static_cast<char>(cp);
        } else if (cp < 0x800) {
            out += static_cast<char>(0xC0 | (cp >> 6));
            out += static_cast<char>(0x80 | (cp & 0x3F));
        } else if (cp < 0x10000) {
            out += static_cast<char>(0xE0 | (cp >> 12));
            out += static_cast<char>(0x80 | ((cp >> 6) & 0x3F));
            out += static_cast<char>(0x80 | (cp & 0x3F));
        } else {
            out += static_cast<char>(0xF0 | (cp >> 18));
            out += static_cast<char>(0x80 | ((cp >> 12) & 0x3F));
            out += static_cast<char>(0x80 | ((cp >> 6) & 0x3F));
            out += static_cast<char>(0x80 | (cp & 0x3F));
        }
    };

    bool inString = false;
    std::string current;
    for (std::size_t i = 0; i < arr.size(); ++i) {
        const char c = arr[i];
        if (!inString) {
            if (c == '\"') {
                inString = true;
                current.clear();
            }
            continue;
        }
        if (c == '\"') {
            values.push_back(current);
            inString = false;
            continue;
        }
        if (c != '\\') {
            current += c;
            continue;
        }
        if (++i >= arr.size()) break;
        const char e = arr[i];
        switch (e) {
        case 'b': current += '\b'; break;
        case 'f': current += '\f'; break;
        case 'n': current += '\n'; break;
        case 'r': current += '\r'; break;
        case 't': current += '\t'; break;
        case 'u': {
            long cp = read_hex4(i + 1);
            if (cp < 0) {
                current += 'u';  // malformed escape: treat like any unknown escape
                break;
            }
            i += 4;
            if (cp >= 0xD800 && cp <= 0xDBFF) {
                long low = -1;
                if (i + 2 < arr.size() && arr[i + 1] == '\\' && arr[i + 2] == 'u') {
                    low = read_hex4(i + 3);
                }
                if (low >= 0xDC00 && low <= 0xDFFF) {
                    cp = 0x10000 + ((cp - 0xD800) << 10) + (low - 0xDC00);
                    i += 6;
                } else {
                    cp = 0xFFFD;
                }
            } else if (cp >= 0xDC00 && cp <= 0xDFFF) {
                cp = 0xFFFD;
            }
            append_utf8(current, static_cast<unsigned long>(cp));
            break;
        }
        default: current += e; break;
        }
    }
    return values;
}
/// Parses the numeric array stored under `key` into floats.
///
/// Scanning is lenient:
///   - any character that is not a digit or a sign acts as a separator;
///   - a token is the following run of digits, '.', 'e'/'E' and signs;
///   - tokens that do not start with a parseable number (e.g. a lone "-")
///     are skipped.
/// Values outside float range are kept as strtof returns them (±HUGE_VALF,
/// or a subnormal/zero on underflow) rather than dropped. This keeps the
/// vector length equal to the number of numeric tokens; upsert_memory derives
/// the embedding dimension from that length.
/// NOTE: strtof honours LC_NUMERIC, as std::stof did before.
inline std::vector<float> parse_float_array(const std::string& json, const std::string& key) {
    std::vector<float> values;
    const auto segment = find_array_segment(json, key);
    if (!segment) return values;
    const std::string& arr = *segment;

    const auto is_token_start = [](char ch) {
        return std::isdigit(static_cast<unsigned char>(ch)) || ch == '-' || ch == '+';
    };
    const auto is_token_char = [](char ch) {
        return std::isdigit(static_cast<unsigned char>(ch)) || ch == '.' || ch == 'e' || ch == 'E' ||
               ch == '+' || ch == '-';
    };

    std::size_t pos = 0;
    while (pos < arr.size()) {
        while (pos < arr.size() && !is_token_start(arr[pos])) ++pos;
        if (pos >= arr.size()) break;
        std::size_t end = pos;
        while (end < arr.size() && is_token_char(arr[end])) ++end;
        const std::string token = arr.substr(pos, end - pos);
        // std::stof throws std::out_of_range whenever errno == ERANGE, which
        // includes subnormal underflow. That silently shortened the vector.
        // std::strtof reports failure only via its end pointer.
        char* parsedEnd = nullptr;
        const float value = std::strtof(token.c_str(), &parsedEnd);
        if (parsedEnd != token.c_str()) values.push_back(value);
        pos = end;
    }
    return values;
}
/// Formats a score as a JSON number with three decimals (e.g. "0.875").
///
/// NaN and ±infinity have no JSON representation ("nan"/"inf" would make the
/// response unparsable). If a backend ever yields one, it is emitted as 0.000
/// so the response stays valid JSON.
inline std::string format_score(float score) {
    const float safeScore = std::isfinite(score) ? score : 0.0f;
    std::ostringstream os;
    os.setf(std::ios::fixed);
    os << std::setprecision(3) << safeScore;
    return os.str();
}
/// One entry of the "items" array in an upsert request, as produced by
/// parse_items(). String fields are empty and vectors are empty when the
/// corresponding JSON field is absent.
struct ParsedItem {
std::string id;                 // "id" string field of the item
std::string text;               // "text" string field of the item
std::vector<std::string> tags;  // string literals from the "tags" array
std::vector<float> embedding;   // numbers from the "embedding" array
std::string rawJson;            // full original text of the item object
};
inline std::vector<ParsedItem> parse_items(const std::string& json) {
std::vector<ParsedItem> items;
for (const auto& obj : parse_object_array(json, "items")) {
ParsedItem item;
item.rawJson = obj;
item.id = extract_string_field(obj, "id");
item.text = extract_string_field(obj, "text");
item.tags = parse_string_array(obj, "tags");
item.embedding = parse_float_array(obj, "embedding");
items.push_back(std::move(item));
}
return items;
}
/// One hit returned by search_memory().
struct SearchMatch {
    std::string id;                   // matched item id
    float score = 0.0f;               // default-initialized so it is never indeterminate
    std::optional<std::string> text;  // item text; omitted from the JSON when not set
};
inline std::string serialize_matches(const std::vector<SearchMatch>& matches) {
std::ostringstream os;
os << "{\"matches\":[";
for (std::size_t i = 0; i < matches.size(); ++i) {
const auto& match = matches[i];
if (i) os << ",";
os << "{\"id\":\"" << json_escape(match.id) << "\""
<< ",\"score\":" << format_score(match.score);
if (match.text) {
os << ",\"text\":\"" << json_escape(*match.text) << "\"";
}
os << "}";
}
os << "]}";
return os.str();
}
} // namespace detail
inline std::string upsert_memory(const std::string& reqJson) {
const std::string nsName = detail::extract_string_field(reqJson, "namespace");
if (nsName.empty()) {
return detail::error_response("bad_request", "namespace is required");
}
auto nsRow = detail::database().ensureNamespace(nsName);
if (!nsRow) {
return detail::error_response("internal_error", "failed to ensure namespace");
}
auto items = detail::parse_items(reqJson);
if (items.empty()) {
return detail::error_response("bad_request", "items array must contain at least one entry");
}
kom::PgDal& dal = detail::database();
const bool hasTx = dal.begin();
std::vector<std::string> ids;
ids.reserve(items.size());
try {
for (auto& parsed : items) {
kom::ItemRow row;
row.id = parsed.id;
row.namespace_id = nsRow->id;
row.content_json = parsed.rawJson;
row.text = parsed.text.empty() ? std::optional<std::string>() : std::optional<std::string>(parsed.text);
row.tags = parsed.tags;
row.revision = 1;
const std::string itemId = dal.upsertItem(row);
ids.push_back(itemId);
if (!parsed.embedding.empty()) {
kom::ChunkRow chunk;
chunk.item_id = itemId;
chunk.ord = 0;
chunk.text = parsed.text;
auto chunkIds = dal.upsertChunks(std::vector<kom::ChunkRow>{chunk});
kom::EmbeddingRow emb;
emb.chunk_id = chunkIds.front();
emb.model = "stub-model";
emb.dim = static_cast<int>(parsed.embedding.size());
emb.vector = parsed.embedding;
dal.upsertEmbeddings(std::vector<kom::EmbeddingRow>{emb});
}
}
if (hasTx) dal.commit();
} catch (const std::exception& ex) {
if (hasTx) dal.rollback();
return detail::error_response("internal_error", ex.what());
}
std::ostringstream os;
os << "{\"upserted\":" << ids.size();
if (!ids.empty()) {
os << ",\"ids\":[";
for (std::size_t i = 0; i < ids.size(); ++i) {
if (i) os << ",";
os << "\"" << detail::json_escape(ids[i]) << "\"";
}
os << "]";
}
os << ",\"status\":\"ok\"}";
return os.str();
}
/// MCP handler: searches a namespace by query text and/or embedding.
///
/// Request fields read here:
///   {"namespace": "...", "query": {"text": "...", "embedding": [...], "k": N}}
/// Matching proceeds as follows:
///   - Text hits come first, with a rank-based score of 1.0, 0.95, ...
///     floored at 0.
///   - Any remaining slots up to k (default 5) are filled from the vector
///     search.
///   - Ids already returned are skipped.
/// An unknown namespace yields an empty match list.
inline std::string search_memory(const std::string& reqJson) {
    const std::string nsName = detail::extract_string_field(reqJson, "namespace");
    if (nsName.empty()) {
        return detail::error_response("bad_request", "namespace is required");
    }
    kom::PgDal& dal = detail::database();
    auto nsRow = dal.findNamespace(nsName);
    if (!nsRow) {
        return "{\"matches\":[]}";
    }
    std::string queryText;
    std::vector<float> queryEmbedding;
    int limit = 5;
    if (auto queryObj = detail::find_object_segment(reqJson, "query")) {
        queryText = detail::extract_string_field(*queryObj, "text");
        queryEmbedding = detail::parse_float_array(*queryObj, "embedding");
        if (auto k = detail::extract_int_field(*queryObj, "k")) {
            if (*k > 0) limit = *k;
        }
    }
    std::unordered_set<std::string> seen;
    std::vector<detail::SearchMatch> matches;
    const auto full = [&matches, limit] { return static_cast<int>(matches.size()) >= limit; };

    // How searchText treats an empty query is up to the backend. If it
    // returned arbitrary rows, an embedding-only query would fill every slot
    // and the vector search would never run. So the text search runs only when
    // there is text to search for, or when no embedding was supplied.
    const bool runTextSearch = !queryText.empty() || queryEmbedding.empty();
    if (runTextSearch) {
        auto textRows = dal.searchText(nsRow->id, queryText, limit);
        for (std::size_t idx = 0; idx < textRows.size() && !full(); ++idx) {
            const auto& row = textRows[idx];
            detail::SearchMatch match;
            match.id = row.id;
            if (!seen.insert(match.id).second) continue;  // drop duplicate ids
            match.text = row.text;
            // Rank-based score, floored so that a large k never yields negative scores.
            match.score = std::max(0.0f, 1.0f - static_cast<float>(idx) * 0.05f);
            matches.push_back(match);
        }
    }
    if (!full() && !queryEmbedding.empty()) {
        auto vectorMatches = dal.searchVector(nsRow->id, queryEmbedding, limit);
        for (const auto& pair : vectorMatches) {
            if (full()) break;
            if (seen.count(pair.first)) continue;
            auto item = dal.getItemById(pair.first);
            if (!item) continue;
            detail::SearchMatch match;
            match.id = pair.first;
            seen.insert(match.id);  // the vector search may also repeat an id
            match.score = pair.second;
            match.text = item->text;
            matches.push_back(match);
        }
    }
    return detail::serialize_matches(matches);
}
} // namespace Handlers