# Source code for neo4j_graphrag.retrievers.tools_retriever

#  Copyright (c) "Neo4j"
#  Neo4j Sweden AB [https://neo4j.com]
#  #
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#  #
#      https://www.apache.org/licenses/LICENSE-2.0
#  #
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
from __future__ import annotations

from typing import Any, List, Optional, Sequence

import neo4j

from neo4j_graphrag.llm.base import LLMInterface
from neo4j_graphrag.retrievers.base import Retriever
from neo4j_graphrag.types import RawSearchResult
from neo4j_graphrag.tool import Tool
from neo4j_graphrag.types import LLMMessage


class ToolsRetriever(Retriever):
    """A retriever that uses an LLM to select appropriate tools for retrieval based on user input.

    This retriever takes an LLM instance and a list of Tool objects as input.
    When a search is performed, it uses the LLM to analyze the query and determine
    which tools (if any) should be used to retrieve the necessary data. It then
    executes the selected tools and returns the combined results.

    Example:

    .. code-block:: python

        import neo4j
        from neo4j_graphrag.retrievers import ToolsRetriever, VectorRetriever, Text2CypherRetriever
        from neo4j_graphrag.llm import OpenAILLM
        from neo4j_graphrag.embeddings import OpenAIEmbeddings

        driver = neo4j.GraphDatabase.driver("neo4j://localhost:7687", auth=("neo4j", "password"))
        llm = OpenAILLM(model_name="gpt-4", api_key="your-api-key")
        embedder = OpenAIEmbeddings(model="text-embedding-3-small", api_key="your-api-key")

        # Create retrievers and convert them to tools
        vector_retriever = VectorRetriever(driver, "vector-index", embedder)
        vector_tool = vector_retriever.convert_to_tool(
            name="vector_search",
            description="Search for documents using semantic similarity"
        )

        text2cypher_retriever = Text2CypherRetriever(driver, llm)
        cypher_tool = text2cypher_retriever.convert_to_tool(
            name="cypher_search",
            description="Generate and execute Cypher queries for structured data retrieval"
        )

        # Initialize ToolsRetriever with the tools
        tools_retriever = ToolsRetriever(
            driver=driver,
            llm=llm,
            tools=[vector_tool, cypher_tool]
        )

        # Use the retriever - the LLM will automatically select appropriate tools
        result = tools_retriever.search("What movies did Tom Hanks act in and what are their plots?")

    Args:
        driver (neo4j.Driver): The Neo4j Python driver.
        llm (LLMInterface): LLM instance used to select and coordinate tool execution.
        tools (Sequence[Tool]): List of tools available for selection. All tools must have unique names.
        neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults
            to the server's default database ("neo4j" by default).
        system_instruction (Optional[str]): Custom system instruction for the LLM to guide tool
            selection. If not provided (or empty), a default instruction is used.

    Raises:
        ValueError: If duplicate tool names are found in the tools list.
    """

    # Disable Neo4j version verification since this retriever doesn't directly interact with Neo4j
    VERIFY_NEO4J_VERSION = False

    def __init__(
        self,
        driver: neo4j.Driver,
        llm: LLMInterface,
        tools: Sequence[Tool],
        neo4j_database: Optional[str] = None,
        system_instruction: Optional[str] = None,
    ) -> None:
        """Initialize the ToolsRetriever with an LLM and a list of tools."""
        super().__init__(driver, neo4j_database)
        self.llm = llm
        # Copy into a list so the caller's sequence is never aliased.
        self._tools = list(tools)
        self._validate_tool_names()
        # A falsy instruction (None or "") falls back to the default text.
        self.system_instruction = (
            system_instruction or self._get_default_system_instruction()
        )

    def _validate_tool_names(self) -> None:
        """Validate that all tool names are unique.

        Raises:
            ValueError: If two or more tools share a name. Each duplicated name
                is listed once, in the order its first repetition was seen.
        """
        seen_names = set()
        duplicate_names: List[Any] = []
        for tool in self._tools:
            name = tool.get_name()
            if name in seen_names:
                # Report every duplicated name only once, however often it repeats.
                if name not in duplicate_names:
                    duplicate_names.append(name)
            else:
                seen_names.add(name)

        if duplicate_names:
            raise ValueError(
                f"Duplicate tool names found: {duplicate_names}. "
                "All tools must have unique names for proper LLM tool selection."
            )

    def _get_default_system_instruction(self) -> str:
        """Get the default system instruction for the LLM."""
        return (
            "You are an assistant that helps select the most appropriate tools to retrieve information "
            "based on the user's query. Analyze the query carefully and determine which tools, if any, "
            "would be most helpful in retrieving the relevant information. You can select multiple tools "
            "if necessary, or none if no tools are appropriate for the query."
        )

    @staticmethod
    def _records_from_tool_result(
        tool_name: str, tool_result: Any
    ) -> List[neo4j.Record]:
        """Convert one tool's return value into records attributed to that tool.

        Three result shapes are recognised, checked in this order:

        * an object with a non-callable ``items`` attribute: one record per
          item, built from ``item.content`` and ``item.metadata``;
        * an object with a ``records`` attribute: one record per entry, with
          the entry stringified as content and kept as a dict in the metadata;
        * anything else: a single record holding ``str(tool_result)``.

        Args:
            tool_name (str): Name of the tool that produced ``tool_result``.
            tool_result (Any): The value returned by the tool's ``execute``.

        Returns:
            List[neo4j.Record]: Records carrying ``content``, ``tool_name`` and
            ``metadata`` keys.
        """
        # The callable check keeps plain mappings (whose ``items`` is a method)
        # out of this branch.
        if hasattr(tool_result, "items") and not callable(
            getattr(tool_result, "items")
        ):
            return [
                neo4j.Record(
                    {
                        "content": item.content,
                        "tool_name": tool_name,
                        "metadata": {
                            **(item.metadata or {}),
                            "tool": tool_name,
                        },
                    }
                )
                for item in tool_result.items
            ]

        if hasattr(tool_result, "records"):
            # Raw results (legacy): wrap each raw record with tool attribution.
            return [
                neo4j.Record(
                    {
                        "content": str(raw_record),
                        "tool_name": tool_name,
                        "metadata": {
                            "original_record": dict(raw_record),
                            "tool": tool_name,
                        },
                    }
                )
                for raw_record in tool_result.records
            ]

        # Non-retriever tools or simple return values.
        return [
            neo4j.Record(
                {
                    "content": str(tool_result),
                    "tool_name": tool_name,
                    "metadata": {"tool": tool_name},
                }
            )
        ]

    def get_search_results(
        self,
        query_text: str,
        message_history: Optional[List[LLMMessage]] = None,
        **kwargs: Any,
    ) -> RawSearchResult:
        """Use the LLM to select and execute appropriate tools based on the query.

        Args:
            query_text (str): The user's query text.
            message_history (Optional[List[LLMMessage]]): Previous conversation
                history forwarded to the LLM. Defaults to None.
            **kwargs (Any): Accepted for interface compatibility; currently not
                used by this method. Tools are executed only with the arguments
                supplied by the LLM's tool calls.

        Returns:
            RawSearchResult: The combined results from the executed tools. The
            metadata always contains ``query``. It contains ``error`` when no
            tools are configured, ``error`` and ``error_type`` when tool
            selection or execution raised, and ``llm_response`` plus
            ``tools_selected`` otherwise. On any error ``records`` is empty.
        """
        if not self._tools:
            # No tools available, return empty result
            return RawSearchResult(
                records=[],
                metadata={"query": query_text, "error": "No tools available"},
            )

        try:
            # Use the LLM to select appropriate tools
            tool_call_response = self.llm.invoke_with_tools(
                input=query_text,
                tools=self._tools,
                message_history=message_history,
                system_instruction=self.system_instruction,
            )

            # If no tool calls were made, return empty result
            if not tool_call_response.tool_calls:
                return RawSearchResult(
                    records=[],
                    metadata={
                        "query": query_text,
                        "llm_response": tool_call_response.content,
                        "tools_selected": [],
                    },
                )

            # Execute each selected tool and collect results
            all_records: List[neo4j.Record] = []
            tools_selected: List[str] = []

            for tool_call in tool_call_response.tool_calls:
                tool_name = tool_call.name
                # Recorded even if no matching tool exists, so the metadata
                # reflects exactly what the LLM asked for.
                tools_selected.append(tool_name)

                # Find the tool by name
                selected_tool = next(
                    (tool for tool in self._tools if tool.get_name() == tool_name),
                    None,
                )
                if selected_tool is None:
                    # Unknown tool name: nothing to execute, contributes no records.
                    continue

                # Execute the tool with the arguments provided by the LLM
                tool_args = tool_call.arguments or {}
                tool_result = selected_tool.execute(**tool_args)
                all_records.extend(
                    self._records_from_tool_result(tool_name, tool_result)
                )

            # Combine metadata from all tool calls
            combined_metadata = {
                "query": query_text,
                "llm_response": tool_call_response.content,
                "tools_selected": tools_selected,
            }

            return RawSearchResult(records=all_records, metadata=combined_metadata)

        except Exception as e:
            # Deliberate best-effort boundary: any failure during tool selection
            # or execution is reported through the metadata instead of raised,
            # and partial records are discarded.
            return RawSearchResult(
                records=[],
                metadata={
                    "query": query_text,
                    "error": str(e),
                    "error_type": type(e).__name__,
                },
            )