mirror of
https://gitlab.com/gitlab-org/gitlab-foss.git
synced 2025-08-01 16:04:19 +00:00
114 lines
4.0 KiB
Ruby
114 lines
4.0 KiB
Ruby
# frozen_string_literal: true
|
|
|
|
module ActiveContext
  module Preprocessors
    # Concern that adds an `apply_embeddings` class method for generating
    # embeddings from document content and storing them back on each document.
    module Embeddings
      extend ActiveSupport::Concern

      IndexingError = Class.new(StandardError)

      # We cannot inherit from StandardError because this is rescued in `with_batch_handling`
      # This type of error should always fail since we want it to be caught immediately
      # TODO: when https://gitlab.com/gitlab-org/gitlab/-/issues/552302 is completed,
      # we can change this to inherit from StandardError
      EmbeddingsVersionError = Class.new(Exception) # rubocop: disable Lint/InheritException -- see comment above

      class_methods do
        # Generates embeddings for every document of every ref and writes them
        # into the document under each embedding version's `:field` key.
        #
        # @param refs [Enumerable] references responding to `embedding_versions`
        #   and `documents` (documents are used as Hash-like objects here)
        # @param unit_primitive passed through unchanged to `generate_embeddings`
        # @param content_field [Symbol] document key holding the content to embed
        # @param content_method [Symbol, nil] optional method called on the ref to
        #   populate `content_field` on documents that do not have it yet
        # @param remove_content [Boolean] when true, `content_field` is deleted
        #   from each document once its embeddings have been assigned
        # @return the result of `with_batch_handling`; the wrapped block itself
        #   evaluates to `refs`
        # @raise [EmbeddingsVersionError] when an embedding version has a missing
        #   or invalid `:class`
        def apply_embeddings(
          refs:,
          unit_primitive:,
          content_field: :content,
          content_method: nil,
          remove_content: true
        )
          with_batch_handling(refs) do
            docs_to_process = refs.flat_map do |ref|
              next [] unless ref.embedding_versions.any?

              initialize_documents!(ref, content_method, content_field)

              # Create a mapping of reference, document, and embedding versions for processing
              ref.documents.map do |doc|
                {
                  ref: ref,
                  doc: doc,
                  versions: ref.embedding_versions
                }
              end
            end

            # Group documents by their embedding version configuration
            # This allows processing similar documents together with the same embedding model
            version_groups = docs_to_process.group_by { |item| item[:versions].map { |v| [v[:field], v[:model]] }.sort }

            version_groups.each_value do |items|
              # Every item in a group shares the same (field, model) pairs, so the
              # first item's versions are used for the whole group.
              versions = items.first[:versions]
              contents = items.map { |item| item[:doc][content_field] }

              embeddings_by_version = generate_embeddings_for_each_version(versions: versions, contents: contents,
                unit_primitive: unit_primitive)

              # Apply the generated embeddings back to each document.
              # `contents` was built in `items` order, so results are looked up by
              # the same index.
              items.each_with_index do |item, index|
                versions.each do |version|
                  item[:doc][version[:field]] = embeddings_by_version[version[:field]][index]
                end

                item[:doc].delete(content_field) if remove_content
              end
            end

            refs
          end
        end

        private

        # Initializes the documents for a reference if they don't exist
        # and populates the content field if a content_method is provided.
        # Does nothing when `content_method` is nil or the ref does not respond to it.
        # Documents that already have `content_field` set are left untouched.
        def initialize_documents!(ref, content_method, content_field)
          return unless content_method && ref.respond_to?(content_method)

          ref.documents << {} if ref.documents.empty?

          ref.documents.each do |doc|
            next if doc.key?(content_field)

            doc[content_field] = ref.send(content_method) # rubocop: disable GitlabSecurity/PublicSend -- method is defined elsewhere
          end
        end

        # Calls `generate_embeddings` once per embedding version for the given
        # contents.
        #
        # @param versions [Array<Hash>] version hashes read via `:class`, `:field`,
        #   `:model` and `:batch_size`
        # @param contents [Array] content values, one per document
        # @param unit_primitive passed through unchanged to `generate_embeddings`
        # @return [Hash] the `generate_embeddings` result keyed by version `:field`
        def generate_embeddings_for_each_version(versions:, contents:, unit_primitive:)
          versions.each_with_object({}) do |version, embeddings_by_version|
            klass = embeddings_class(version)

            embedding = klass.generate_embeddings(
              contents,
              model: version[:model],
              unit_primitive: unit_primitive,
              batch_size: version[:batch_size]
            )
            embeddings_by_version[version[:field]] = embedding
          end
        end

        # Returns the validated embeddings class configured for a version.
        #
        # @param embeddings_version [Hash] version hash with `:class` and `:field`
        # @return [Module] the configured class
        # @raise [EmbeddingsVersionError] when `:class` is nil, is not a
        #   Class/Module, or does not descend from `ActiveContext::Embeddings`
        def embeddings_class(embeddings_version)
          klass = embeddings_version[:class]
          field = embeddings_version[:field]

          raise EmbeddingsVersionError, "No `class` specified for model version `#{field}`." if klass.nil?

          # The `is_a?(Module)` guard keeps `<=` from being called on non-class
          # values (e.g. a String class name). That call would raise a
          # StandardError, which would be rescued instead of surfacing as an
          # EmbeddingsVersionError.
          unless klass.is_a?(Module) && klass <= ActiveContext::Embeddings
            raise(
              EmbeddingsVersionError,
              "Specified class for model version `#{field}` must inherit from `#{ActiveContext::Embeddings}`."
            )
          end

          klass
        end
      end
    end
  end
end
|