Files
nextcloud-recognize/lib/Clustering/MrdBallTree.php
Marcel Klehr b0d4ea8610 fix(psalm): Fix psalm issue
Signed-off-by: Marcel Klehr <mklehr@gmx.net>
2025-04-09 13:31:14 +02:00

679 lines
23 KiB
PHP

<?php
/*
* Copyright (c) 2023 The Recognize contributors.
* This file is licensed under the Affero General Public License version 3 or later. See the COPYING file.
*/
declare(strict_types=1);
namespace OCA\Recognize\Clustering;
use \OCA\Recognize\Vendor\Rubix\ML\Datasets\Labeled;
use \OCA\Recognize\Vendor\Rubix\ML\Graph\Nodes\Ball;
use \OCA\Recognize\Vendor\Rubix\ML\Graph\Nodes\Hypersphere;
use \OCA\Recognize\Vendor\Rubix\ML\Graph\Trees\BallTree;
use \OCA\Recognize\Vendor\Rubix\ML\Kernels\Distance\Distance;
final class MrdBallTree extends BallTree {
private ?Labeled $dataset = null;
private array $nativeInterpointCache = [];
/**
* @var array<array-key,list<float>>
*/
private array $coreDistances = [];
/**
* @var array<array-key,list<float>>
*/
private array $coreNeighborDistances = [];
private int $sampleSize;
private array $nodeDistances;
private \SplObjectStorage $nodeIds;
private float $radiusDiamFactor;
/**
* @param int $maxLeafSize
* @param int $coreDistSampleSize
* @param \OCA\Recognize\Vendor\Rubix\ML\Kernels\Distance\Distance|null $kernel
* @throws \OCA\Recognize\Vendor\Rubix\ML\Exceptions\InvalidArgumentException
*/
public function __construct(int $maxLeafSize = 30, int $sampleSize = 5, ?Distance $kernel = null) {
if ($maxLeafSize < 1) {
throw new \InvalidArgumentException('At least one sample is required'
. " to form a leaf node, $maxLeafSize given.");
}
$this->maxLeafSize = $maxLeafSize;
$this->sampleSize = $sampleSize;
$this->kernel = $kernel ?? new SquaredDistance();
$this->radiusDiamFactor = $this->kernel instanceof SquaredDistance ? 4 : 2;
$this->nodeDistances = [];
$this->nodeIds = new \SplObjectStorage();
}
public function getCoreNeighborDistances(): array {
return $this->coreNeighborDistances;
}
/**
* @param \OCA\Recognize\Clustering\DualTreeBall|\OCA\Recognize\Clustering\DualTreeClique $queryNode
* @param \OCA\Recognize\Clustering\DualTreeBall|\OCA\Recognize\Clustering\DualTreeClique $referenceNode
* @return float
*/
public function nodeDistance($queryNode, $referenceNode): float {
// Use cache to accelerate repeated queries
if ($this->nodeIds->contains($queryNode)) {
$queryNodeId = $this->nodeIds[$queryNode];
} else {
$queryNodeId = $this->nodeIds->count();
$this->nodeIds[$queryNode] = $queryNodeId;
}
if ($this->nodeIds->contains($referenceNode)) {
$referenceNodeId = $this->nodeIds[$referenceNode];
} else {
$referenceNodeId = $this->nodeIds->count();
$this->nodeIds[$referenceNode] = $referenceNodeId;
}
if ($referenceNodeId === $queryNodeId) {
return (-$this->radiusDiamFactor * $queryNode->radius());
}
$smallIndex = min($queryNodeId, $referenceNodeId);
$largeIndex = max($queryNodeId, $referenceNodeId);
if (isset($this->nodeDistances[$smallIndex][$largeIndex])) {
$nodeDistance = $this->nodeDistances[$smallIndex][$largeIndex];
} else {
$nodeDistance = $this->kernel->compute($queryNode->center(), $referenceNode->center());
if ($this->kernel instanceof SquaredDistance) {
$nodeDistance = sqrt($nodeDistance) - sqrt($queryNode->radius()) - sqrt($referenceNode->radius());
$nodeDistance = abs($nodeDistance) * $nodeDistance;
} else {
$nodeDistance = $nodeDistance - $queryNode->radius() - $referenceNode->radius();
}
$this->nodeDistances[$smallIndex][$largeIndex] = $nodeDistance;
}
return $nodeDistance;
}
/**
* Get tree root.
*
* @internal
*
* @return DualTreeBall
*/
public function getRoot(): Ball {
return $this->root;
}
/**
* Get the dataset the tree was grown on.
*
* @internal
*
* @return Labeled
*/
public function getDataset(): Labeled {
return $this->dataset;
}
/**
* @param \OCA\Recognize\Clustering\DualTreeBall|\OCA\Recognize\Clustering\DualTreeClique $queryNode
* @param \OCA\Recognize\Clustering\DualTreeBall|\OCA\Recognize\Clustering\DualTreeClique $referenceNode
* @param int $k
* @param float $maxRange
* @param array<array-key, float> $bestDistances
* @return void
*/
private function updateNearestNeighbors($queryNode, $referenceNode, $k, $maxRange, &$bestDistances): void {
$querySamples = $queryNode->dataset()->samples();
$queryLabels = $queryNode->dataset()->labels();
$referenceSamples = $referenceNode->dataset()->samples();
$referenceLabels = $referenceNode->dataset()->labels();
$longestDistance = 0.0;
$shortestDistance = INF;
foreach ($querySamples as $queryKey => $querySample) {
$queryLabel = $queryLabels[$queryKey];
$bestDistance = $bestDistances[$queryLabel];
foreach ($referenceSamples as $referenceKey => $referenceSample) {
$referenceLabel = $referenceLabels[$referenceKey];
if ($queryLabel === $referenceLabel) {
continue;
}
// Calculate native distance
$distance = $this->cachedComputeNative($queryLabel, $querySample, $referenceLabel, $referenceSample);
if ($distance < $bestDistance) {
//Minimize array queries within these loops:
$coreNeighborDistances = & $this->coreNeighborDistances[$queryLabel];
$coreNeighborDistances[$referenceLabel] = $distance;
if (count($coreNeighborDistances) >= $k) {
asort($coreNeighborDistances);
$coreNeighborDistances = array_slice($coreNeighborDistances, 0, $k, true);
$bestDistance = (float) min(end($coreNeighborDistances), $maxRange);
}
}
}
if ($bestDistance > $longestDistance) {
$longestDistance = $bestDistance;
}
if ($bestDistance < $shortestDistance) {
$shortestDistance = $bestDistance;
}
$bestDistances[$queryLabel] = $bestDistance;
}
if ($this->kernel instanceof SquaredDistance) {
$longestDistance = min($longestDistance, (2.0 * sqrt($queryNode->radius()) + sqrt($shortestDistance)) ** 2.0);
} else {
$longestDistance = (float) min($longestDistance, 2.0 * $queryNode->radius() + $shortestDistance);
}
$queryNode->setLongestDistance($longestDistance);
}
private function findNearestNeighbors($queryNode, $referenceNode, $k, $maxRange, &$bestDistances): void {
$nodeDistance = $this->nodeDistance($queryNode, $referenceNode);
if ($nodeDistance > 0.0) {
// Calculate smallest possible bound (i.e., d(Q) ):
$currentBound = $queryNode->getLongestDistance();
// If node distance is greater than the longest possible edge in this node,
// prune this reference node
if ($nodeDistance > $currentBound) {
return;
}
}
if ($queryNode instanceof DualTreeClique && $referenceNode instanceof DualTreeClique) {
$this->updateNearestNeighbors($queryNode, $referenceNode, $k, $maxRange, $bestDistances);
return;
}
if ($queryNode instanceof DualTreeClique) {
foreach ($referenceNode->children() as $child) {
$this->findNearestNeighbors($queryNode, $child, $k, $maxRange, $bestDistances);
}
return;
}
if ($referenceNode instanceof DualTreeClique) {
$longestDistance = 0.0;
$queryLeft = $queryNode->left();
$queryRight = $queryNode->right();
$this->findNearestNeighbors($queryLeft, $referenceNode, $k, $maxRange, $bestDistances);
$this->findNearestNeighbors($queryRight, $referenceNode, $k, $maxRange, $bestDistances);
} else {
// --> if ($queryNode instanceof DualTreeBall && $referenceNode instanceof DualTreeBall)
$queryLeft = $queryNode->left();
$queryRight = $queryNode->right();
$referenceLeft = $referenceNode->left();
$referenceRight = $referenceNode->right();
// TODO: traverse closest neighbor nodes first
$this->findNearestNeighbors($queryLeft, $referenceLeft, $k, $maxRange, $bestDistances);
$this->findNearestNeighbors($queryLeft, $referenceRight, $k, $maxRange, $bestDistances);
$this->findNearestNeighbors($queryRight, $referenceLeft, $k, $maxRange, $bestDistances);
$this->findNearestNeighbors($queryRight, $referenceRight, $k, $maxRange, $bestDistances);
}
$longestLeft = $queryLeft->getLongestDistance();
$longestRight = $queryRight->getLongestDistance();
// TODO: min($longestLeft, $longestRight) + 2 * ($queryNode->radius()) <--- Can be made tighter by using the shortest distance from child.
if ($this->kernel instanceof SquaredDistance) {
$longestDistance = max($longestLeft, $longestRight);
$longestLeft = (sqrt($longestLeft) + 2.0 * (sqrt($queryNode->radius()) - sqrt($queryLeft->radius()))) ** 2.0;
$longestRight = (sqrt($longestRight) + 2.0 * (sqrt($queryNode->radius()) - sqrt($queryRight->radius()))) ** 2.0;
$longestDistance = (float) min($longestDistance, min($longestLeft, $longestRight), (sqrt(min($longestLeft, $longestRight)) + 2.0 * (sqrt($queryNode->radius()))) ** 2.0);
} else {
$longestDistance = max($longestLeft, $longestRight);
$longestLeft = $longestLeft + 2 * ($queryNode->radius() - $queryLeft->radius());
$longestRight = $longestRight + 2 * ($queryNode->radius() - $queryRight->radius());
$longestDistance = (float) min($longestDistance, min($longestLeft, $longestRight), min($longestLeft, $longestRight) + 2.0 * ($queryNode->radius()));
}
$queryNode->setLongestDistance($longestDistance);
return;
}
public function kNearestAll($k, float $maxRange = INF): void {
$this->coreNeighborDistances = [];
$allLabels = $this->dataset->labels();
$bestDistances = [];
foreach ($allLabels as $label) {
$bestDistances[$label] = $maxRange;
}
if ($this->root === null) {
return;
}
$treeRoot = $this->root;
$treeRoot->resetFullyConnectedStatus();
$treeRoot->resetLongestEdge();
$this->findNearestNeighbors($treeRoot, $treeRoot, $k, $maxRange, $bestDistances);
}
/**
* Precompute core distances for the current dataset to accelerate
* subsequent MRD queries. Optionally, utilize core distances that've
* been previously determined for (a subset of) the current dataset.
* Returns the updated core distances for future use.
*
* @internal
*
* @param array|null $oldCoreNeighbors
* @return array
*/
public function precalculateCoreDistances(?array $oldCoreNeighbors = null) {
if (empty($this->dataset)) {
throw new \Exception("Precalculation of core distances requested but dataset is empty. Call ->grow() first!");
}
$labels = $this->dataset->labels();
if ($oldCoreNeighbors !== null && !empty($oldCoreNeighbors) && count(reset($oldCoreNeighbors)) >= $this->sampleSize) {
// Determine the search radius for core distances based on the largest old
// core distance (points father than that cannot shorten the old core distances)
$largestOldCoreDistance = 0.0;
// Utilize old (possibly stale) core distance data
foreach ($oldCoreNeighbors as $label => $oldDistances) {
$coreDistance = (array_values($oldDistances))[$this->sampleSize - 1];
if ($coreDistance > $largestOldCoreDistance) {
$largestOldCoreDistance = $coreDistance;
}
$this->coreNeighborDistances[$label] = $oldDistances;
$this->coreDistances[$label] = $coreDistance;
}
$updatedOldCoreLabels = [];
// Don't recalculate core distances for the old labels
$labels = array_filter($labels, function ($label) use ($oldCoreNeighbors) {
return !isset($oldCoreNeighbors[$label]);
});
foreach ($labels as $label) {
[$neighborLabels, $neighborDistances] = $this->cachedRange($label, $largestOldCoreDistance);
// TODO: cachedRange may not return $this->sampleSize number of labels.
$this->coreNeighborDistances[$label] = array_combine(array_slice($neighborLabels, 0, $this->sampleSize), array_slice($neighborDistances, 0, $this->sampleSize));
$this->coreDistances[$label] = $neighborDistances[$this->sampleSize - 1];
// If one of the old vertices is within the update radius of this new vertex,
// check whether the old core distance needs to be updated.
foreach ($neighborLabels as $distanceKey => $neighborLabel) {
if (isset($oldCoreNeighbors[$neighborLabel])) {
$newDistance = $neighborDistances[$distanceKey];
if ($newDistance < $this->coreDistances[$neighborLabel]) {
$this->coreNeighborDistances[$neighborLabel][$label] = $newDistance;
$updatedOldCoreLabels[$neighborLabel] = true;
}
}
}
}
foreach (array_keys($updatedOldCoreLabels) as $label) {
asort($this->coreNeighborDistances[$label]);
$this->coreNeighborDistances[$label] = array_slice($this->coreNeighborDistances[$label], 0, $this->sampleSize, true);
$this->coreDistances[$label] = end($this->coreNeighborDistances[$label]);
}
} else { // $oldCoreNeighbors === null
$this->kNearestAll($this->sampleSize, INF);
foreach ($this->dataset->labels() as $label) {
$this->coreDistances[$label] = end($this->coreNeighborDistances[$label]);
}
}
return $this->coreNeighborDistances;
}
/**
* Inserts a new neighbor to core neighbors if the distance
* is greater than the current largest distance for the query label.
*
* Returns the updated core distance or INF if there are less than $this->sampleSize neighbors.
*
* @internal
*
* @param mixed $queryLabel
* @param mixed $referenceLabel
* @param float $distance
* @return float
*/
private function insertToCoreDistances($queryLabel, $referenceLabel, $distance): float {
// Update the core distances of the queryLabel
if (isset($this->coreDistances[$queryLabel])) {
if ($this->coreDistances[$queryLabel] > $distance) {
$this->coreNeighborDistances[$queryLabel][$referenceLabel] = $distance;
asort($this->coreNeighborDistances[$queryLabel]);
$this->coreNeighborDistances[$queryLabel] = array_slice($this->coreNeighborDistances[$queryLabel], 0, $this->sampleSize, true);
$this->coreDistances[$queryLabel] = end($this->coreNeighborDistances[$queryLabel]);
}
} else {
$this->coreNeighborDistances[$queryLabel][$referenceLabel] = $distance;
if (count($this->coreNeighborDistances[$queryLabel]) >= $this->sampleSize) {
asort($this->coreNeighborDistances[$queryLabel]);
$this->coreNeighborDistances[$queryLabel] = array_slice($this->coreNeighborDistances[$queryLabel], 0, $this->sampleSize, true);
$this->coreDistances[$queryLabel] = end($this->coreNeighborDistances[$queryLabel]);
}
}
// Update the core distances of the referenceLabel (this is not necessary, but *may* accelerate the algo slightly)
if (isset($this->coreDistances[$referenceLabel])) {
if ($this->coreDistances[$referenceLabel] > $distance) {
$this->coreNeighborDistances[$referenceLabel][$queryLabel] = $distance;
asort($this->coreNeighborDistances[$referenceLabel]);
$this->coreNeighborDistances[$referenceLabel] = array_slice($this->coreNeighborDistances[$referenceLabel], 0, $this->sampleSize, true);
$this->coreDistances[$referenceLabel] = end($this->coreNeighborDistances[$referenceLabel]);
}
} else {
$this->coreNeighborDistances[$referenceLabel][$queryLabel] = $distance;
if (count($this->coreNeighborDistances[$referenceLabel]) > $this->sampleSize) {
asort($this->coreNeighborDistances[$referenceLabel]);
$this->coreNeighborDistances[$referenceLabel] = array_slice($this->coreNeighborDistances[$referenceLabel], 0, $this->sampleSize, true);
$this->coreDistances[$referenceLabel] = end($this->coreNeighborDistances[$referenceLabel]);
}
}
return $this->coreDistances[$queryLabel] ?? INF;
}
/**
* Compute the mutual reachability distance between two vectors.
*
* @internal
*
* @param int|string $a
* @param int|string $b
* @return float
*/
public function computeMrd($a, array $a_vector, $b, array $b_vector): float {
$distance = $this->cachedComputeNative($a, $a_vector, $b, $b_vector);
return max($distance, $this->getCoreDistance($a), $this->getCoreDistance($b));
}
public function getCoreDistance($label): float {
if (!isset($this->coreDistances[$label])) {
[$labels, $distances] = $this->getCoreNeighbors($label);
$this->coreDistances[$label] = end($distances);
}
return $this->coreDistances[$label];
}
public function cachedComputeNative($a, array $a_vector, $b, array $b_vector, bool $storeNewCalculations = true): float {
if (isset($this->coreNeighborDistances[$a][$b])) {
return $this->coreNeighborDistances[$a][$b];
}
if (isset($this->coreNeighborDistances[$b][$a])) {
return $this->coreNeighborDistances[$b][$a];
}
if ($a < $b) {
$smallIndex = $a;
$largeIndex = $b;
} else {
$smallIndex = $b;
$largeIndex = $a;
}
if (!isset($this->nativeInterpointCache[$smallIndex][$largeIndex])) {
$distance = $this->kernel->compute($a_vector, $b_vector);
if ($storeNewCalculations) {
$this->nativeInterpointCache[$smallIndex][$largeIndex] = $distance;
}
return $distance;
}
return $this->nativeInterpointCache[$smallIndex][$largeIndex];
}
/**
* Run a n nearest neighbors search on a single label and return the neighbor labels, and distances in a 2-tuple
*
*
* @internal
*
* @param int|string $sampleLabel
* @param bool $useCachedValues
* @throws \OCA\Recognize\Vendor\Rubix\ML\Exceptions\InvalidArgumentException
* @return array{list<mixed>,list<float>}
*/
public function getCoreNeighbors($sampleLabel, bool $useCachedValues = true): array {
if ($useCachedValues && isset($this->coreNeighborDistances[$sampleLabel])) {
return [array_keys($this->coreNeighborDistances[$sampleLabel]), array_values($this->coreNeighborDistances[$sampleLabel])];
}
$sampleKey = array_search($sampleLabel, $this->dataset->labels());
$sample = $this->dataset->sample($sampleKey);
$squaredDistance = $this->kernel instanceof SquaredDistance;
/** @var list<DualTreeBall|DualTreeClique> **/
$stack = [$this->root];
$stackDistances = [0.0];
$radius = INF;
$labels = $distances = [];
while ($current = array_pop($stack)) {
$currentDistance = array_pop($stackDistances);
if ($currentDistance > $radius) {
continue;
}
if ($current instanceof DualTreeBall) {
foreach ($current->children() as $child) {
if ($child instanceof Hypersphere) {
$distance = $this->kernel->compute($sample, $child->center());
if ($squaredDistance) {
$distance = sqrt($distance);
$childRadius = sqrt($child->radius());
$distance = $distance - $childRadius;
$distance = abs($distance) * $distance;
} else {
$distance = $distance - $child->radius();
}
if ($distance < $radius) {
$stack[] = $child;
$stackDistances[] = $distance;
}
}
}
array_multisort($stackDistances, SORT_DESC, $stack);
} elseif ($current instanceof DualTreeClique) {
$dataset = $current->dataset();
$neighborLabels = $dataset->labels();
foreach ($dataset->samples() as $i => $neighbor) {
if ($neighborLabels[$i] === $sampleLabel) {
continue;
}
$distance = $this->cachedComputeNative($sampleLabel, $sample, $neighborLabels[$i], $neighbor);
if ($distance <= $radius) {
$labels[] = $neighborLabels[$i];
$distances[] = $distance;
}
}
if (count($labels) >= $this->sampleSize) {
array_multisort($distances, $labels);
$radius = $distances[$this->sampleSize - 1];
$labels = array_slice($labels, 0, $this->sampleSize);
$distances = array_slice($distances, 0, $this->sampleSize);
}
}
}
return [$labels, $distances];
}
/**
* Return all labels, and distances within a given radius of a sample.
*
*
* @internal
*
* @param int $sampleLabel
* @param float $radius
* @throws \OCA\Recognize\Vendor\Rubix\ML\Exceptions\InvalidArgumentException
* @throws \OCA\Recognize\Vendor\Rubix\ML\Exceptions\RuntimeException
* @return array{list<mixed>,list<float>}
*/
public function cachedRange($sampleLabel, float $radius): array {
$sampleKey = array_search($sampleLabel, $this->dataset->labels());
$sample = $this->dataset->sample($sampleKey);
$squaredDistance = $this->kernel instanceof SquaredDistance;
/** @var list<DualTreeBall|DualTreeClique> **/
$stack = [$this->root];
$labels = $distances = [];
while ($current = array_pop($stack)) {
if ($current instanceof DualTreeBall) {
foreach ($current->children() as $child) {
if ($child instanceof Hypersphere) {
$distance = $this->kernel->compute($sample, $child->center());
if ($squaredDistance) {
$distance = sqrt($distance);
$childRadius = sqrt($child->radius());
$minDistance = $distance - $childRadius;
$minDistance = abs($minDistance) * $minDistance;
$maxDistance = ($distance + $childRadius) ** 2.0;
} else {
$childRadius = $child->radius();
$minDistance = $distance - $childRadius;
$maxDistance = $distance + $childRadius;
}
if ($minDistance < $radius) {
if ($maxDistance < $radius && $child instanceof DualTreeBall) {
// The whole child is within the specified radius: greedily add all sub-children recursively to the stack
$subStack = [$child];
while ($subStackCurrent = array_pop($subStack)) {
foreach ($subStackCurrent->children() as $subChild) {
if ($subChild instanceof DualTreeClique) {
$stack[] = $subChild;
} else {
$subStack[] = $subChild;
}
}
}
} else {
$stack[] = $child;
}
}
}
}
} elseif ($current instanceof DualTreeClique) {
$dataset = $current->dataset();
$neighborLabels = $dataset->labels();
foreach ($dataset->samples() as $i => $neighbor) {
$distance = $this->cachedComputeNative($sampleLabel, $sample, $neighborLabels[$i], $neighbor);
if ($distance <= $radius) {
$labels[] = $neighborLabels[$i];
$distances[] = $distance;
}
}
}
}
array_multisort($distances, $labels);
return [$labels, $distances];
}
/**
* Insert a root node and recursively split the dataset until a terminating
* condition is met. This also sets the dataset that will be used to calculate
* core distances. Previously calculated core distances will be stored/used
* despite calling grow, unless precalculateCoreDistances() is called again.
*
* @internal
*
* @param \OCA\Recognize\Vendor\Rubix\ML\Datasets\Labeled $dataset
* @throws \OCA\Recognize\Vendor\Rubix\ML\Exceptions\InvalidArgumentException
*/
public function grow(Labeled $dataset): void {
$this->dataset = $dataset;
$this->root = DualTreeBall::split($dataset, $this->kernel);
$stack = [$this->root];
while ($current = array_pop($stack)) {
[$left, $right] = $current->subsets();
$current->cleanup();
if ($left->numSamples() > $this->maxLeafSize) {
$node = DualTreeBall::split($left, $this->kernel);
[$subLeft, $subRight] = $node->subsets();
if ($subLeft->empty() || $subRight->empty()) {
$current->attachLeft(DualTreeClique::terminate($left, $this->kernel));
} else {
$current->attachLeft($node);
$stack[] = $node;
}
} else {
$current->attachLeft(DualTreeClique::terminate($left, $this->kernel));
}
if ($right->numSamples() > $this->maxLeafSize) {
$node = DualTreeBall::split($right, $this->kernel);
[$subLeft, $subRight] = $node->subsets();
if ($node->isPoint() || $subLeft->empty() || $subRight->empty()) {
$current->attachRight(DualTreeClique::terminate($right, $this->kernel));
} else {
$current->attachRight($node);
$stack[] = $node;
}
} else {
$current->attachRight(DualTreeClique::terminate($right, $this->kernel));
}
}
}
}