Source code for neo4j_graphrag.experimental.components.embedder

#  Copyright (c) "Neo4j"
#  Neo4j Sweden AB [https://neo4j.com]
#  #
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#  #
#      https://www.apache.org/licenses/LICENSE-2.0
#  #
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
import asyncio
from pydantic import validate_call

from neo4j_graphrag.embeddings.base import Embedder
from neo4j_graphrag.experimental.components.types import TextChunk, TextChunks
from neo4j_graphrag.experimental.pipeline.component import Component


class TextChunkEmbedder(Component):
    """Component for creating embeddings from text chunks.

    Args:
        embedder (Embedder): The embedder to use to create the embeddings.
        max_concurrency (int): The maximum number of concurrent tasks which can be
            used to make requests to the embedder. Default is 5.

    Example:

    .. code-block:: python

        from neo4j_graphrag.experimental.components.embedder import TextChunkEmbedder
        from neo4j_graphrag.embeddings.openai import OpenAIEmbeddings
        from neo4j_graphrag.experimental.pipeline import Pipeline

        embedder = OpenAIEmbeddings(model="text-embedding-3-large")
        chunk_embedder = TextChunkEmbedder(embedder)
        pipeline = Pipeline()
        pipeline.add_component(chunk_embedder, "chunk_embedder")
    """

    def __init__(self, embedder: Embedder, max_concurrency: int = 5):
        self._embedder = embedder
        self.max_concurrency = max_concurrency

    def _embed_chunk(self, text_chunk: TextChunk) -> TextChunk:
        """Embed a single text chunk.

        Args:
            text_chunk (TextChunk): The text chunk to embed.

        Returns:
            TextChunk: The text chunk with an added "embedding" key in its
                metadata containing the embedding of the text chunk's text.
        """
        embedding = self._embedder.embed_query(text_chunk.text)
        metadata = text_chunk.metadata if text_chunk.metadata else {}
        metadata["embedding"] = embedding
        return TextChunk(
            text=text_chunk.text,
            index=text_chunk.index,
            metadata=metadata,
            uid=text_chunk.uid,
        )

    async def _async_embed_chunk(
        self, sem: asyncio.Semaphore, text_chunk: TextChunk
    ) -> TextChunk:
        """Asynchronously embed a single text chunk.

        Args:
            sem (asyncio.Semaphore): Semaphore used to limit concurrency.
            text_chunk (TextChunk): The text chunk to embed.

        Returns:
            TextChunk: The text chunk with an added "embedding" key in its
                metadata containing the embedding of the text chunk's text.
        """
        async with sem:
            embedding = await self._embedder.async_embed_query(text_chunk.text)
            metadata = text_chunk.metadata if text_chunk.metadata else {}
            metadata["embedding"] = embedding
            return TextChunk(
                text=text_chunk.text,
                index=text_chunk.index,
                metadata=metadata,
                uid=text_chunk.uid,
            )

    @validate_call
    async def run(self, text_chunks: TextChunks) -> TextChunks:
        """Embed a list of text chunks.

        Args:
            text_chunks (TextChunks): The text chunks to embed.

        Returns:
            TextChunks: The input text chunks with each one having an added
                embedding.
        """
        sem = asyncio.Semaphore(self.max_concurrency)
        tasks = [
            self._async_embed_chunk(sem, text_chunk)
            for text_chunk in text_chunks.chunks
        ]
        chunks: list[TextChunk] = list(await asyncio.gather(*tasks))
        return TextChunks(chunks=chunks)
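

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the module above): a minimal way to
# call TextChunkEmbedder.run() directly, outside of a Pipeline. It assumes an
# OpenAI API key is available to OpenAIEmbeddings (e.g. via the environment);
# the chunk texts below are hypothetical placeholders.

import asyncio

from neo4j_graphrag.embeddings.openai import OpenAIEmbeddings
from neo4j_graphrag.experimental.components.embedder import TextChunkEmbedder
from neo4j_graphrag.experimental.components.types import TextChunk, TextChunks


async def main() -> None:
    # Cap concurrent embedding requests at 2 via max_concurrency.
    chunk_embedder = TextChunkEmbedder(
        embedder=OpenAIEmbeddings(model="text-embedding-3-large"),
        max_concurrency=2,
    )
    chunks = TextChunks(
        chunks=[
            TextChunk(text="Neo4j is a graph database.", index=0),
            TextChunk(text="Text chunks are embedded before being written to the graph.", index=1),
        ]
    )
    result = await chunk_embedder.run(chunks)
    # Each returned chunk now carries its vector under metadata["embedding"].
    print(len(result.chunks[0].metadata["embedding"]))


if __name__ == "__main__":
    asyncio.run(main())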