# Copyright (c) "Neo4j"
# Neo4j Sweden AB [https://neo4j.com]
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# https://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import os
import logging
from abc import abstractmethod
from typing import Any, Generator, Literal, Optional
import neo4j
from pydantic import validate_call
from neo4j_graphrag.experimental.components.filename_collision_handler import (
FilenameCollisionHandler,
)
from neo4j_graphrag.experimental.components.parquet_formatter import (
Neo4jGraphParquetFormatter,
)
from neo4j_graphrag.experimental.components.parquet_output import (
ParquetOutputDestination,
)
from neo4j_graphrag.experimental.components.types import (
LexicalGraphConfig,
Neo4jGraph,
Neo4jNode,
Neo4jRelationship,
)
from neo4j_graphrag.experimental.pipeline.component import Component, DataModel
from neo4j_graphrag.neo4j_queries import (
upsert_node_query,
upsert_relationship_query,
db_cleaning_query,
)
from neo4j_graphrag.utils.version_utils import (
get_version,
is_version_5_23_or_above,
is_version_5_24_or_above,
)
from neo4j_graphrag.utils import driver_config
logger = logging.getLogger(__name__)
def _build_columns_from_schema(
schema: Any, primary_key_names: list[str]
) -> list[dict[str, Any]]:
"""Build a list of column dicts (name, type, is_primary_key) from a PyArrow schema."""
columns: list[dict[str, Any]] = []
for i in range(len(schema)):
field = schema.field(i)
type_info = Neo4jGraphParquetFormatter.pyarrow_type_to_type_info(field.type)
columns.append(
{
"name": field.name,
"type": type_info.source_type,
"is_primary_key": field.name in primary_key_names,
}
)
return columns
def batched(rows: list[Any], batch_size: int) -> Generator[list[Any], None, None]:
index = 0
for i in range(0, len(rows), batch_size):
start = i
end = min(start + batch_size, len(rows))
batch = rows[start:end]
yield batch
index += 1
def _graph_stats(
graph: Neo4jGraph,
nodes_per_label: Optional[dict[str, int]] = None,
rel_per_type: Optional[dict[str, int]] = None,
input_files_count: int = 0,
input_files_total_size_bytes: int = 0,
) -> dict[str, Any]:
"""Build the statistics dict for writer metadata.
Schema:
node_count, relationship_count, nodes_per_label, rel_per_type,
input_files_count, input_files_total_size_bytes.
"""
if nodes_per_label is None:
nodes_per_label = {}
for node in graph.nodes:
nodes_per_label[node.label] = nodes_per_label.get(node.label, 0) + 1
if rel_per_type is None:
rel_per_type = {}
for rel in graph.relationships:
rel_per_type[rel.type] = rel_per_type.get(rel.type, 0) + 1
return {
"node_count": len(graph.nodes),
"relationship_count": len(graph.relationships),
"nodes_per_label": nodes_per_label,
"rel_per_type": rel_per_type,
"input_files_count": input_files_count,
"input_files_total_size_bytes": input_files_total_size_bytes,
}
[docs]
class KGWriterModel(DataModel):
"""Data model for the output of the Knowledge Graph writer.
Attributes:
status: Whether the write operation was successful ("SUCCESS" or "FAILURE").
metadata: Optional dict. When status is SUCCESS, contains at least:
- "statistics": dict with node_count, relationship_count, nodes_per_label,
rel_per_type, input_files_count, input_files_total_size_bytes.
- "files": list of file descriptors with file_path, etc. (ParquetWriter).
"""
status: Literal["SUCCESS", "FAILURE"]
metadata: Optional[dict[str, Any]] = None
[docs]
class KGWriter(Component):
"""Abstract class used to write a knowledge graph to a data store."""
[docs]
@abstractmethod
@validate_call
async def run(
self,
graph: Neo4jGraph,
lexical_graph_config: LexicalGraphConfig = LexicalGraphConfig(),
) -> KGWriterModel:
"""
Writes the graph to a data store.
Args:
graph (Neo4jGraph): The knowledge graph to write to the data store.
lexical_graph_config (LexicalGraphConfig): Node labels and relationship types in the lexical graph.
"""
pass
[docs]
class Neo4jWriter(KGWriter):
"""Writes a knowledge graph to a Neo4j database.
Args:
driver (neo4j.driver): The Neo4j driver to connect to the database.
neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to the server's default database ("neo4j" by default) (`see reference to documentation <https://neo4j.com/docs/operations-manual/current/database-administration/#manage-databases-default>`_).
batch_size (int): The number of nodes or relationships to write to the database in a batch. Defaults to 1000.
Example:
.. code-block:: python
from neo4j import GraphDatabase
from neo4j_graphrag.experimental.components.kg_writer import Neo4jWriter
from neo4j_graphrag.experimental.pipeline import Pipeline
URI = "neo4j://localhost:7687"
AUTH = ("neo4j", "password")
DATABASE = "neo4j"
driver = GraphDatabase.driver(URI, auth=AUTH)
writer = Neo4jWriter(driver=driver, neo4j_database=DATABASE)
pipeline = Pipeline()
pipeline.add_component(writer, "writer")
"""
def __init__(
self,
driver: neo4j.Driver,
neo4j_database: Optional[str] = None,
batch_size: int = 1000,
clean_db: bool = True,
):
self.driver = driver_config.override_user_agent(driver)
self.neo4j_database = neo4j_database
self.batch_size = batch_size
self._clean_db = clean_db
version_tuple, _, _ = get_version(self.driver, self.neo4j_database)
self.is_version_5_23_or_above = is_version_5_23_or_above(version_tuple)
self.is_version_5_24_or_above = is_version_5_24_or_above(version_tuple)
def _db_setup(self) -> None:
self.driver.execute_query("""
CREATE INDEX __entity__tmp_internal_id IF NOT EXISTS FOR (n:__KGBuilder__) ON (n.__tmp_internal_id)
""")
@staticmethod
def _nodes_to_rows(
nodes: list[Neo4jNode], lexical_graph_config: LexicalGraphConfig
) -> list[dict[str, Any]]:
rows = []
for node in nodes:
labels = [node.label]
if node.label not in lexical_graph_config.lexical_graph_node_labels:
labels.append("__Entity__")
row = node.model_dump()
row["labels"] = labels
rows.append(row)
return rows
def _upsert_nodes(
self, nodes: list[Neo4jNode], lexical_graph_config: LexicalGraphConfig
) -> None:
"""Upserts a batch of nodes into the Neo4j database.
Args:
nodes (list[Neo4jNode]): The nodes batch to upsert into the database.
"""
parameters = {"rows": self._nodes_to_rows(nodes, lexical_graph_config)}
query = upsert_node_query(
support_variable_scope_clause=self.is_version_5_23_or_above,
support_dynamic_labels=self.is_version_5_24_or_above,
)
self.driver.execute_query(
query,
parameters_=parameters,
database_=self.neo4j_database,
)
return None
@staticmethod
def _relationships_to_rows(
relationships: list[Neo4jRelationship],
) -> list[dict[str, Any]]:
return [relationship.model_dump() for relationship in relationships]
def _upsert_relationships(self, rels: list[Neo4jRelationship]) -> None:
"""Upserts a batch of relationships into the Neo4j database.
Args:
rels (list[Neo4jRelationship]): The relationships batch to upsert into the database.
"""
parameters = {"rows": self._relationships_to_rows(rels)}
query = upsert_relationship_query(
support_variable_scope_clause=self.is_version_5_23_or_above
)
self.driver.execute_query(
query,
parameters_=parameters,
database_=self.neo4j_database,
)
def _db_cleaning(self) -> None:
query = db_cleaning_query(
support_variable_scope_clause=self.is_version_5_23_or_above,
batch_size=self.batch_size,
)
with self.driver.session() as session:
session.run(query)
[docs]
@validate_call
async def run(
self,
graph: Neo4jGraph,
lexical_graph_config: LexicalGraphConfig = LexicalGraphConfig(),
) -> KGWriterModel:
"""Upserts a knowledge graph into a Neo4j database.
Args:
graph (Neo4jGraph): The knowledge graph to upsert into the database.
lexical_graph_config (LexicalGraphConfig): Node labels and relationship types for the lexical graph.
"""
try:
self._db_setup()
for batch in batched(graph.nodes, self.batch_size):
self._upsert_nodes(batch, lexical_graph_config)
for batch in batched(graph.relationships, self.batch_size):
self._upsert_relationships(batch)
if self._clean_db:
self._db_cleaning()
return KGWriterModel(
status="SUCCESS",
metadata={
"statistics": _graph_stats(graph),
"files": [],
},
)
except neo4j.exceptions.ClientError as e:
logger.exception(e)
return KGWriterModel(status="FAILURE", metadata={"error": str(e)})
class ParquetWriter(KGWriter):
"""Writes a knowledge graph to Parquet files using Neo4jGraphParquetFormatter.
Writes one Parquet file per node label and one per (head_label, relationship_type, tail_label)
to the given destinations, e.g. ``Person.parquet``, ``Person_KNOWS_Person.parquet``.
Args:
nodes_dest (ParquetOutputDestination): Destination for node Parquet files.
relationships_dest (ParquetOutputDestination): Destination for relationship Parquet files.
collision_handler (FilenameCollisionHandler): Handler for resolving filename collisions.
prefix (str): Optional filename prefix for all written files. Defaults to "".
Example:
.. code-block:: python
from neo4j_graphrag.experimental.components.filename_collision_handler import FilenameCollisionHandler
from neo4j_graphrag.experimental.components.kg_writer import ParquetWriter
from neo4j_graphrag.experimental.components.parquet_output import ParquetOutputDestination
from neo4j_graphrag.experimental.pipeline import Pipeline
# Provide your own implementation of ParquetOutputDestination (local, GCS, S3, etc.)
nodes_dest: ParquetOutputDestination = ...
relationships_dest: ParquetOutputDestination = ...
writer = ParquetWriter(
nodes_dest=nodes_dest,
relationships_dest=relationships_dest,
collision_handler=FilenameCollisionHandler(),
)
pipeline = Pipeline()
pipeline.add_component(writer, "writer")
"""
def __init__(
self,
nodes_dest: ParquetOutputDestination,
relationships_dest: ParquetOutputDestination,
collision_handler: FilenameCollisionHandler,
prefix: str = "",
) -> None:
self.nodes_dest = nodes_dest
self.relationships_dest = relationships_dest
self.collision_handler = collision_handler
self.prefix = prefix
@validate_call
async def run(
self,
graph: Neo4jGraph,
lexical_graph_config: LexicalGraphConfig = LexicalGraphConfig(),
schema: Optional[dict[str, Any]] = None,
) -> KGWriterModel:
"""Write the knowledge graph to Parquet files via Neo4jGraphParquetFormatter.
Args:
graph (Neo4jGraph): The knowledge graph to write.
lexical_graph_config (LexicalGraphConfig): Used by the formatter for
lexical graph labels (e.g. __Entity__) and key properties.
schema (Optional[dict[str, Any]]): Optional GraphSchema as a dictionary for
uniqueness constraints and key properties. If not provided, ``__id__`` is used.
"""
try:
formatter = Neo4jGraphParquetFormatter(schema=schema)
data, file_metadata, stats = formatter.format_graph(
graph, lexical_graph_config, prefix=self.prefix
)
meta_by_filename: dict[str, Any] = {m.filename: m for m in file_metadata}
files: list[dict[str, Any]] = []
node_label_to_source_name: dict[str, str] = {}
base_nodes = self.nodes_dest.output_path.rstrip("/")
for filename, content in data["nodes"].items():
meta = meta_by_filename[filename]
unique_filename = self.collision_handler.get_unique_filename(
filename, self.nodes_dest.output_path
)
await self.nodes_dest.write(content, unique_filename)
file_path = os.path.join(base_nodes, unique_filename)
resolved_stem = (
unique_filename[:-8]
if unique_filename.endswith(".parquet")
else unique_filename
)
if meta.node_label is not None:
node_label_to_source_name[meta.node_label] = resolved_stem
columns = _build_columns_from_schema(
meta.schema,
meta.key_properties or [],
)
name = meta.node_label or (
meta.labels[0] if meta.labels else resolved_stem
)
files.append(
{
"name": name,
"file_path": file_path,
"columns": columns,
"is_node": True,
"labels": meta.labels or [],
}
)
base_rel = self.relationships_dest.output_path.rstrip("/")
for filename, content in data["relationships"].items():
meta = meta_by_filename[filename]
unique_filename = self.collision_handler.get_unique_filename(
filename, self.relationships_dest.output_path
)
await self.relationships_dest.write(content, unique_filename)
file_path = os.path.join(base_rel, unique_filename)
start_node_source = node_label_to_source_name.get(
meta.relationship_head or "", meta.relationship_head or ""
)
end_node_source = node_label_to_source_name.get(
meta.relationship_tail or "", meta.relationship_tail or ""
)
columns = _build_columns_from_schema(
meta.schema,
["from", "to"],
)
rel_name = (
f"{meta.relationship_head}_{meta.relationship_type}_{meta.relationship_tail}"
if meta.relationship_head
and meta.relationship_type
and meta.relationship_tail
else unique_filename[:-8]
if unique_filename.endswith(".parquet")
else unique_filename
)
files.append(
{
"name": rel_name,
"file_path": file_path,
"columns": columns,
"is_node": False,
"relationship_type": meta.relationship_type,
"start_node_source": start_node_source,
"start_node_primary_keys": meta.head_node_key_properties
or ["__id__"],
"end_node_source": end_node_source,
"end_node_primary_keys": meta.tail_node_key_properties
or ["__id__"],
}
)
logger.info(
"Wrote %d node files and %d relationship files",
len(data["nodes"]),
len(data["relationships"]),
)
statistics = _graph_stats(
graph,
nodes_per_label=stats["nodes_per_label"],
rel_per_type=stats["rel_per_type"],
input_files_count=0,
input_files_total_size_bytes=0,
)
return KGWriterModel(
status="SUCCESS",
metadata={
"statistics": statistics,
"files": files,
},
)
except Exception as e:
logger.exception(e)
return KGWriterModel(status="FAILURE", metadata={"error": str(e)})