ki: implement multi-text embeddings in OllamaProvider::embed by batching one request per text and aggregating results
This commit is contained in:
parent
e1eb3afacc
commit
2ee757b64a
|
|
@ -108,43 +108,43 @@ QFuture<KIReply*> OllamaProvider::chat(const KIThread& thread, const KIChatOptio
|
|||
|
||||
QFuture<KIEmbeddingResult> OllamaProvider::embed(const QStringList& texts, const KIEmbedOptions& opts)
|
||||
{
|
||||
QNetworkRequest req{QUrl(QStringLiteral("http://localhost:11434/api/embeddings"))};
|
||||
// Execute one request per input text; aggregate outputs.
|
||||
QFutureInterface<KIEmbeddingResult> fi;
|
||||
fi.reportStarted();
|
||||
if (texts.isEmpty()) { KIEmbeddingResult r; r.model = opts.model; fi.reportResult(r); fi.reportFinished(); return fi.future(); }
|
||||
|
||||
struct Accum { QVector<QVector<float>> vectors; int remaining = 0; QString model; };
|
||||
auto acc = new Accum();
|
||||
acc->vectors.resize(texts.size());
|
||||
acc->remaining = texts.size();
|
||||
|
||||
const QUrl url(QStringLiteral("http://localhost:11434/api/embeddings"));
|
||||
for (int i = 0; i < texts.size(); ++i) {
|
||||
QNetworkRequest req{url};
|
||||
req.setHeader(QNetworkRequest::ContentTypeHeader, QStringLiteral("application/json"));
|
||||
|
||||
QJsonObject data;
|
||||
data["model"] = opts.model;
|
||||
data["prompt"] = texts.join("\n"); // Join all texts into a single prompt
|
||||
|
||||
auto netReply = m_manager->post(req, QJsonDocument(data).toJson());
|
||||
|
||||
QFutureInterface<KIEmbeddingResult> interface;
|
||||
interface.reportStarted();
|
||||
|
||||
connect(netReply, &QNetworkReply::finished, this, [netReply, interface]() mutable {
|
||||
if (netReply->error() != QNetworkReply::NoError) {
|
||||
// TODO: Handle error
|
||||
interface.reportFinished();
|
||||
netReply->deleteLater();
|
||||
return;
|
||||
const QJsonObject body{ {QStringLiteral("model"), opts.model}, {QStringLiteral("prompt"), texts[i]} };
|
||||
auto rep = m_manager->post(req, QJsonDocument(body).toJson());
|
||||
connect(rep, &QNetworkReply::finished, this, [rep, i, acc, fi]() mutable {
|
||||
if (rep->error() == QNetworkReply::NoError) {
|
||||
const auto obj = QJsonDocument::fromJson(rep->readAll()).object();
|
||||
if (acc->model.isEmpty()) acc->model = obj.value(QStringLiteral("model")).toString();
|
||||
const auto arr = obj.value(QStringLiteral("embedding")).toArray();
|
||||
QVector<float> vec; vec.reserve(arr.size());
|
||||
for (const auto &v : arr) vec.push_back(static_cast<float>(v.toDouble()));
|
||||
acc->vectors[i] = std::move(vec);
|
||||
}
|
||||
|
||||
const auto json = QJsonDocument::fromJson(netReply->readAll());
|
||||
const auto embeddingArray = json["embedding"].toArray();
|
||||
|
||||
KIEmbeddingResult result;
|
||||
QVector<float> embedding;
|
||||
for (const QJsonValue &value : embeddingArray) {
|
||||
embedding.push_back(value.toDouble());
|
||||
rep->deleteLater();
|
||||
acc->remaining -= 1;
|
||||
if (acc->remaining == 0) {
|
||||
KIEmbeddingResult res; res.vectors = std::move(acc->vectors); res.model = acc->model;
|
||||
fi.reportResult(res);
|
||||
fi.reportFinished();
|
||||
delete acc;
|
||||
}
|
||||
result.vectors.push_back(embedding);
|
||||
result.model = json["model"].toString();
|
||||
|
||||
interface.reportResult(result);
|
||||
interface.reportFinished();
|
||||
netReply->deleteLater();
|
||||
});
|
||||
}
|
||||
|
||||
return interface.future();
|
||||
return fi.future();
|
||||
}
|
||||
|
||||
void OllamaProvider::cancel(quint64 requestId)
|
||||
|
|
|
|||
Loading…
Reference in New Issue