optionally search one language only

This commit is contained in:
Andreas Gohr
2023-08-28 18:09:15 +02:00
parent 754b83949a
commit e33a1d7adc
12 changed files with 184 additions and 29 deletions

16
AIChat.php Normal file
View File

@ -0,0 +1,16 @@
<?php
namespace dokuwiki\plugin\aichat;
/**
* AIChat constants
*/
class AIChat
{
/** @var int preferUIlanguage config: guess language use, all sources */
const LANG_AUTO_ALL = 0;
/** @var int preferUIlanguage config: use UI language, all sources */
const LANG_UI_ALL = 1;
/** @var int preferUIlanguage config: use UI language, limit sources */
const LANG_UI_LIMITED = 2;
}

View File

@ -16,6 +16,8 @@ class Chunk implements \JsonSerializable
protected $created;
/** @var int */
protected $score;
/** @var string */
protected $language;
/**
* @param string $page
@ -24,12 +26,13 @@ class Chunk implements \JsonSerializable
* @param float[] $embedding
* @param int $created
*/
public function __construct($page, $id, $text, $embedding, $created = '', $score = 0)
public function __construct($page, $id, $text, $embedding, $lang = '', $created = '', $score = 0)
{
$this->page = $page;
$this->id = $id;
$this->text = $text;
$this->embedding = $embedding;
$this->language = $lang ?: $this->determineLanguage();
$this->created = $created ?: time();
$this->score = $score;
}
@ -135,6 +138,42 @@ class Chunk implements \JsonSerializable
$this->score = $score;
}
/**
* @return string
*/
public function getLanguage(): string
{
return $this->language;
}
/**
* @param string $language
*/
public function setLanguage($language): void
{
$this->language = $language;
}
/**
* Initialize the language of the chunk
*
* When the translation plugin is available it is used to determine the language, otherwise the default language
* is used.
*
* @return string The lanuaage code
*/
protected function determineLanguage()
{
global $conf;
/** @var \helper_plugin_translation $trans */
$trans = plugin_load('helper', 'translation');
if ($trans) {
$lc = $trans->realLC($trans->getLangPart($this->page));
} else {
$lc = $conf['lang'];
}
return $lc;
}
/**
@ -151,6 +190,7 @@ class Chunk implements \JsonSerializable
$data['id'],
$data['text'],
$data['embedding'],
$data['language'] ?? '',
$data['created']
);
}
@ -163,6 +203,7 @@ class Chunk implements \JsonSerializable
'id' => $this->id,
'text' => $this->text,
'embedding' => $this->embedding,
'language' => $this->language,
'created' => $this->created,
];
}

View File

@ -177,10 +177,11 @@ class Embeddings
* The number of returned chunks depends on the MAX_CONTEXT_LEN setting.
*
* @param string $query The question
* @param string $lang Limit results to this language
* @return Chunk[]
* @throws \Exception
*/
public function getSimilarChunks($query)
public function getSimilarChunks($query, $lang='')
{
global $auth;
$vector = $this->model->getEmbedding($query);
@ -191,7 +192,7 @@ class Embeddings
);
$time = microtime(true);
$chunks = $this->storage->getSimilarChunks($vector, $fetch);
$chunks = $this->storage->getSimilarChunks($vector, $lang, $fetch);
if ($this->logger) {
$this->logger->info(
'Fetched {count} similar chunks from store in {time} seconds',

View File

@ -108,10 +108,11 @@ abstract class AbstractStorage
* If not, the storage should return twice the $limit of chunks and the caller will filter out the readable ones.
*
* @param float[] $vector The vector to compare to
* @param string $lang Limit results to this language. When empty consider all languages
* @param int $limit The number of results to return, see note above
* @return Chunk[]
*/
abstract public function getSimilarChunks($vector, $limit = 4);
abstract public function getSimilarChunks($vector, $lang='', $limit = 4);
/**
* Get information about the storage

View File

@ -88,6 +88,7 @@ class PineconeStorage extends AbstractStorage
$chunkID,
$vector['metadata']['text'],
$vector['values'],
$vector['metadata']['language'] ?? '',
$vector['metadata']['created']
);
}
@ -186,6 +187,7 @@ class PineconeStorage extends AbstractStorage
$vector['id'],
$vector['metadata']['text'],
$vector['values'],
$vector['metadata']['language'] ?? '',
$vector['metadata']['created']
);
}
@ -193,10 +195,16 @@ class PineconeStorage extends AbstractStorage
}
/** @inheritdoc */
public function getSimilarChunks($vector, $limit = 4)
public function getSimilarChunks($vector, $lang = '', $limit = 4)
{
$limit = $limit * 2; // we can't check ACLs, so we return more than requested
if ($lang) {
$filter = ['language' => ['$eq', $lang]];
} else {
$filter = [];
}
$response = $this->runQuery(
'/query',
[
@ -204,6 +212,7 @@ class PineconeStorage extends AbstractStorage
'topK' => (int)$limit,
'include_metadata' => true,
'include_values' => true,
'filter' => $filter,
]
);
$chunks = [];
@ -213,6 +222,7 @@ class PineconeStorage extends AbstractStorage
$vector['id'],
$vector['metadata']['text'],
$vector['values'],
$vector['metadata']['language'] ?? '',
$vector['metadata']['created'],
$vector['score']
);

View File

@ -2,10 +2,10 @@
namespace dokuwiki\plugin\aichat\Storage;
use dokuwiki\plugin\aichat\AIChat;
use dokuwiki\plugin\aichat\Chunk;
use dokuwiki\plugin\sqlite\SQLiteDB;
use KMeans\Cluster;
use KMeans\Point;
use KMeans\Space;
/**
@ -26,6 +26,8 @@ class SQLiteStorage extends AbstractStorage
/** @var SQLiteDB */
protected $db;
protected $useLanguageClusters = false;
/**
* Initializes the database connection and registers our custom function
*
@ -35,6 +37,9 @@ class SQLiteStorage extends AbstractStorage
{
$this->db = new SQLiteDB('aichat', DOKU_PLUGIN . 'aichat/db/');
$this->db->getPdo()->sqliteCreateFunction('COSIM', [$this, 'sqliteCosineSimilarityCallback'], 2);
$helper = plugin_load('helper', 'aichat');
$this->useLanguageClusters = $helper->getConf('preferUIlanguage') >= AIChat::LANG_UI_LIMITED;
}
/** @inheritdoc */
@ -48,6 +53,7 @@ class SQLiteStorage extends AbstractStorage
$record['id'],
$record['chunk'],
json_decode($record['embedding'], true),
$record['lang'],
$record['created']
);
}
@ -82,7 +88,8 @@ class SQLiteStorage extends AbstractStorage
'id' => $chunk->getId(),
'chunk' => $chunk->getText(),
'embedding' => json_encode($chunk->getEmbedding()),
'created' => $chunk->getCreated()
'created' => $chunk->getCreated(),
'lang' => $chunk->getLanguage(),
]);
}
}
@ -119,6 +126,7 @@ class SQLiteStorage extends AbstractStorage
$record['id'],
$record['chunk'],
json_decode($record['embedding'], true),
$record['lang'],
$record['created']
);
}
@ -126,9 +134,9 @@ class SQLiteStorage extends AbstractStorage
}
/** @inheritdoc */
public function getSimilarChunks($vector, $limit = 4)
public function getSimilarChunks($vector, $lang = '', $limit = 4)
{
$cluster = $this->getCluster($vector);
$cluster = $this->getCluster($vector, $lang);
if ($this->logger) $this->logger->info(
'Using cluster {cluster} for similarity search', ['cluster' => $cluster]
);
@ -150,6 +158,7 @@ class SQLiteStorage extends AbstractStorage
$record['id'],
$record['chunk'],
json_decode($record['embedding'], true),
$record['lang'],
$record['created'],
$record['similarity']
);
@ -164,7 +173,7 @@ class SQLiteStorage extends AbstractStorage
$size = $this->db->queryValue(
'SELECT page_count * page_size as size FROM pragma_page_count(), pragma_page_size()'
);
$query = "SELECT cluster, COUNT(*) || ' chunks' as cnt FROM embeddings GROUP BY cluster ORDER BY cluster";
$query = "SELECT cluster || ' ' || lang, COUNT(*) || ' chunks' as cnt FROM embeddings GROUP BY cluster ORDER BY cluster";
$clusters = $this->db->queryKeyValueList($query);
return [
@ -208,27 +217,52 @@ class SQLiteStorage extends AbstractStorage
/**
* Create new clusters based on random chunks
*
* @noinspection SqlWithoutWhere
* @return void
*/
protected function createClusters()
{
if ($this->logger) $this->logger->info('Creating new clusters...');
if($this->useLanguageClusters) {
$result = $this->db->queryAll('SELECT DISTINCT lang FROM embeddings');
$langs = array_column($result, 'lang');
foreach ($langs as $lang) {
$this->createLanguageClusters($lang);
}
} else {
$this->createLanguageClusters('');
}
}
/**
* Create new clusters based on random chunks for the given Language
*
* @param string $lang The language to cluster, empty when all languages go into the same cluster
* @noinspection SqlWithoutWhere
*/
protected function createLanguageClusters($lang)
{
if($lang != '') {
$where = 'WHERE lang = '. $this->db->getPdo()->quote($lang);
} else {
$where = '';
}
if ($this->logger) $this->logger->info('Creating new {lang} clusters...', ['lang' => $lang]);
$this->db->getPdo()->beginTransaction();
try {
// clean up old cluster data
$query = 'DELETE FROM clusters';
$query = "DELETE FROM clusters $where";
$this->db->exec($query);
$query = 'UPDATE embeddings SET cluster = NULL';
$query = "UPDATE embeddings SET cluster = NULL $where";
$this->db->exec($query);
// get a random selection of chunks
$query = 'SELECT id, embedding FROM embeddings ORDER BY RANDOM() LIMIT ?';
$query = "SELECT id, embedding FROM embeddings $where ORDER BY RANDOM() LIMIT ?";
$result = $this->db->queryAll($query, [self::SAMPLE_SIZE]);
if (!$result) return; // no data to cluster
$dimensions = count(json_decode($result[0]['embedding'], true));
// get the number of all chunks, to calculate the number of clusters
$query = 'SELECT COUNT(*) FROM embeddings';
$query = "SELECT COUNT(*) FROM embeddings $where";
$total = $this->db->queryValue($query);
$clustercount = ceil($total / self::CLUSTER_SIZE);
if ($this->logger) $this->logger->info('Creating {clusters} clusters', ['clusters' => $clustercount]);
@ -253,15 +287,15 @@ class SQLiteStorage extends AbstractStorage
foreach ($clusters as $clusterID => $cluster) {
/** @var Cluster $cluster */
$centroid = $cluster->getCoordinates();
$query = 'INSERT INTO clusters (cluster, centroid) VALUES (?, ?)';
$this->db->exec($query, [$clusterID, json_encode($centroid)]);
$query = 'INSERT INTO clusters (lang, centroid) VALUES (?, ?)';
$this->db->exec($query, [$lang, json_encode($centroid)]);
}
$this->db->getPdo()->commit();
if ($this->logger) $this->logger->success('Created {clusters} clusters', ['clusters' => count($clusters)]);
} catch (\Exception $e) {
$this->db->getPdo()->rollBack();
throw new \RuntimeException('Clustering failed', 0, $e);
throw new \RuntimeException('Clustering failed: '.$e->getMessage(), 0, $e);
}
}
@ -273,12 +307,12 @@ class SQLiteStorage extends AbstractStorage
protected function setChunkClusters()
{
if ($this->logger) $this->logger->info('Assigning clusters to chunks...');
$query = 'SELECT id, embedding FROM embeddings WHERE cluster IS NULL';
$query = 'SELECT id, embedding, lang FROM embeddings WHERE cluster IS NULL';
$handle = $this->db->query($query);
while ($record = $handle->fetch(\PDO::FETCH_ASSOC)) {
$vector = json_decode($record['embedding'], true);
$cluster = $this->getCluster($vector);
$cluster = $this->getCluster($vector, $this->useLanguageClusters ? $record['lang'] : '');
$query = 'UPDATE embeddings SET cluster = ? WHERE id = ?';
$this->db->exec($query, [$cluster, $record['id']]);
if ($this->logger) $this->logger->success(
@ -294,9 +328,20 @@ class SQLiteStorage extends AbstractStorage
* @param float[] $vector
* @return int|null
*/
protected function getCluster($vector)
protected function getCluster($vector, $lang)
{
$query = 'SELECT cluster, centroid FROM clusters ORDER BY COSIM(centroid, ?) DESC LIMIT 1';
if($lang != '') {
$where = 'WHERE lang = '. $this->db->getPdo()->quote($lang);
} else {
$where = '';
}
$query = "SELECT cluster, centroid
FROM clusters
$where
ORDER BY COSIM(centroid, ?) DESC
LIMIT 1";
$result = $this->db->queryRecord($query, [json_encode($vector)]);
if (!$result) return null;
return $result['cluster'];

View File

@ -236,7 +236,12 @@ class cli_plugin_aichat extends CLIPlugin
*/
protected function similar($query)
{
$sources = $this->helper->getEmbeddings()->getSimilarChunks($query);
$langlimit = $this->helper->getLanguageLimit();
if ($langlimit) {
$this->info('Limiting results to {lang}', ['lang' => $langlimit]);
}
$sources = $this->helper->getEmbeddings()->getSimilarChunks($query, $langlimit);
$this->printSources($sources);
}

View File

@ -22,4 +22,8 @@ $meta['pinecone_baseurl'] = array('string');
$meta['logging'] = array('onoff');
$meta['restrict'] = array('string');
$meta['preferUIlanguage'] = array('onoff');
$meta['preferUIlanguage'] = array('multichoice', '_choices' => array(
\dokuwiki\plugin\aichat\AIChat::LANG_AUTO_ALL,
\dokuwiki\plugin\aichat\AIChat::LANG_UI_ALL,
\dokuwiki\plugin\aichat\AIChat::LANG_UI_LIMITED,
));

View File

@ -1 +1 @@
2
3

12
db/update0003.sql Normal file
View File

@ -0,0 +1,12 @@
ALTER TABLE embeddings ADD COLUMN lang NOT NULL DEFAULT '';
CREATE INDEX embeddings_lang_idx ON embeddings (lang);
DROP TABLE clusters;
CREATE TABLE clusters
(
cluster INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
lang TEXT NOT NULL DEFAULT '',
centroid BLOB NOT NULL,
created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX clusters_lang_idx ON clusters (lang);

View File

@ -1,6 +1,7 @@
<?php
use dokuwiki\Extension\CLIPlugin;
use dokuwiki\plugin\aichat\AIChat;
use dokuwiki\plugin\aichat\Chunk;
use dokuwiki\plugin\aichat\Embeddings;
use dokuwiki\plugin\aichat\Model\AbstractModel;
@ -149,7 +150,7 @@ class helper_plugin_aichat extends \dokuwiki\Extension\Plugin
*/
public function askQuestion($question, $previous = [])
{
$similar = $this->getEmbeddings()->getSimilarChunks($question);
$similar = $this->getEmbeddings()->getSimilarChunks($question, $this->getLanguageLimit());
if ($similar) {
$context = implode("\n", array_map(function (Chunk $chunk) {
return "\n```\n" . $chunk->getText() . "\n```\n";
@ -254,7 +255,7 @@ class helper_plugin_aichat extends \dokuwiki\Extension\Plugin
{
global $conf;
if ($this->getConf('preferUIlanguage')) {
if ($this->getConf('preferUIlanguage') > AIChat::LANG_AUTO_ALL) {
$isoLangnames = include(__DIR__ . '/lang/languages.php');
if (isset($isoLangnames[$conf['lang']])) {
$languagePrompt = 'Always answer in ' . $isoLangnames[$conf['lang']] . '.';
@ -265,5 +266,20 @@ class helper_plugin_aichat extends \dokuwiki\Extension\Plugin
$languagePrompt = 'Always answer in the user\'s language.';
return $languagePrompt;
}
/**
* Should sources be limited to current language?
*
* @return string The current language code or empty string
*/
public function getLanguageLimit()
{
if ($this->getConf('preferUIlanguage') >= AIChat::LANG_UI_LIMITED) {
global $conf;
return $conf['lang'];
} else {
return '';
}
}
}

View File

@ -10,4 +10,8 @@ $lang['openaiorg'] = 'Your OpenAI organization ID (if any)';
$lang['model'] = 'Which model to use. When changing models, be sure to run <code>php bin/plugin.php aichat embed -c</code> to rebuild the vector storage.';
$lang['logging'] = 'Log all questions and answers. Use the <a href="?do=admin&page=logviewer&facility=aichat">Log Viewer</a> to access.';
$lang['restrict'] = 'Restrict access to these users and groups (comma separated). Leave empty to allow all users.';
$lang['preferUIlanguage'] = 'Prefer the configured UI language when answering questions instead of guessing which language the user used in their question.';
$lang['preferUIlanguage'] = 'How to work with multilingual wikis? (Requires the translation plugin)';
$lang['preferUIlanguage_o_0'] = 'Guess language, use all sources';
$lang['preferUIlanguage_o_1'] = 'Prefer UI language, use all sources';
$lang['preferUIlanguage_o_2'] = 'Prefer UI language, same language sources only';