diff --git a/examples/question_answering/graphrag_pipeline.py b/examples/question_answering/graphrag_pipeline.py new file mode 100644 index 000000000..2fb129c89 --- /dev/null +++ b/examples/question_answering/graphrag_pipeline.py @@ -0,0 +1,70 @@ +import asyncio + +import neo4j +from neo4j_graphrag.embeddings import OpenAIEmbeddings +from neo4j_graphrag.experimental.components.rag.generate import Generator +from neo4j_graphrag.experimental.components.rag.prompt_builder import PromptBuilder +from neo4j_graphrag.experimental.components.rag.retrievers import RetrieverWrapper +from neo4j_graphrag.experimental.pipeline import Pipeline +from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult +from neo4j_graphrag.generation import RagTemplate +from neo4j_graphrag.llm import OpenAILLM +from neo4j_graphrag.retrievers import VectorRetriever + +URI = "neo4j+s://demo.neo4jlabs.com" +AUTH = ("recommendations", "recommendations") +DATABASE = "recommendations" +INDEX_NAME = "moviePlotsEmbedding" + + +async def main() -> PipelineResult: + pipeline = Pipeline() + driver = neo4j.GraphDatabase.driver(URI, auth=AUTH) + llm = OpenAILLM(model_name="gpt-4o") + embedder = OpenAIEmbeddings() + retriever = VectorRetriever( + driver, + index_name=INDEX_NAME, + neo4j_database=DATABASE, + embedder=embedder, + ) + pipeline.add_component(RetrieverWrapper(retriever), "retriever") + pipeline.add_component(PromptBuilder(RagTemplate()), "prompt") + pipeline.add_component(Generator(llm), "generate") + + pipeline.connect( + "retriever", + "prompt", + { + "context": "retriever.result", + }, + ) + pipeline.connect( + "prompt", + "generate", + { + "prompt": "prompt.prompt", + }, + ) + + query = "show me a movie with cats" + res = await pipeline.run( + { + "retriever": {"query_text": query}, + "prompt": {"query_text": query, "examples": ""}, + } + ) + + driver.close() + await llm.async_client.close() + + # context_result = await pipeline.store.get_result_for_component( + # res.run_id, "retriever" + # ) + # context = context_result.get("result") + # res.result["context"] = context + return res + + +if __name__ == "__main__": + print(asyncio.run(main())) diff --git a/examples/question_answering/graphrag_simple_pipeline.py b/examples/question_answering/graphrag_simple_pipeline.py new file mode 100644 index 000000000..8c46492f5 --- /dev/null +++ b/examples/question_answering/graphrag_simple_pipeline.py @@ -0,0 +1,26 @@ +from neo4j_graphrag.experimental.pipeline.config.runner import PipelineRunner + +if __name__ == "__main__": + import asyncio + import os + + os.environ["NEO4J_URI"] = "neo4j+s://demo.neo4jlabs.com" + os.environ["NEO4J_USER"] = "recommendations" + os.environ["NEO4J_PASSWORD"] = "recommendations" + + runner = PipelineRunner.from_config_file( + "examples/question_answering/simple_rag_pipeline_config.json" + ) + print( + asyncio.run( + runner.run( + dict( + query_text="show me a movie about cats", + retriever_config={ + "top_k": 2, + }, + # return_context=True, + ) + ) + ) + ) diff --git a/examples/question_answering/simple_rag_pipeline_config.json b/examples/question_answering/simple_rag_pipeline_config.json new file mode 100644 index 000000000..a3b463e24 --- /dev/null +++ b/examples/question_answering/simple_rag_pipeline_config.json @@ -0,0 +1,57 @@ +{ + "version_": "1", + "template_": "SimpleRAGPipeline", + "neo4j_config": { + "params_": { + "uri": { + "resolver_": "ENV", + "var_": "NEO4J_URI" + }, + "user": { + "resolver_": "ENV", + "var_": "NEO4J_USER" + }, + "password": { + "resolver_": "ENV", + "var_": "NEO4J_PASSWORD" + } + } + }, + "llm_config": { + "class_": "OpenAILLM", + "params_": { + "api_key": { + "resolver_": "ENV", + "var_": "OPENAI_API_KEY" + }, + "model_name": "gpt-4o", + "model_params": { + "temperature": 0, + "max_tokens": 2000 + } + } + }, + "embedder_config": { + "class_": "OpenAIEmbeddings", + "params_": { + "api_key": { + "resolver_": "ENV", + "var_": "OPENAI_API_KEY" + } + } + }, + "retriever": { + "class_": "VectorRetriever", + "params_": { + "driver": { + "resolver_": "CONFIG_KEY", + "key_": "neo4j_config.default" + }, + "index_name": "moviePlotsEmbedding", + "embedder": { + "resolver_": "CONFIG_KEY", + "key_": "embedder_config.default" + } + } + } +} diff --git a/src/neo4j_graphrag/experimental/components/rag/__init__.py b/src/neo4j_graphrag/experimental/components/rag/__init__.py new file mode 100644 index 000000000..c0199c144 --- /dev/null +++ b/src/neo4j_graphrag/experimental/components/rag/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/neo4j_graphrag/experimental/components/rag/generate.py b/src/neo4j_graphrag/experimental/components/rag/generate.py new file mode 100644 index 000000000..9b5e45611 --- /dev/null +++ b/src/neo4j_graphrag/experimental/components/rag/generate.py @@ -0,0 +1,31 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from neo4j_graphrag.experimental.pipeline import Component, DataModel +from neo4j_graphrag.llm import LLMInterface + + +class GenerationResult(DataModel): + content: str + + +class Generator(Component): + def __init__(self, llm: LLMInterface) -> None: + self.llm = llm + + async def run(self, prompt: str) -> GenerationResult: + llm_response = await self.llm.ainvoke(prompt) + return GenerationResult( + content=llm_response.content, + ) diff --git a/src/neo4j_graphrag/experimental/components/rag/prompt_builder.py b/src/neo4j_graphrag/experimental/components/rag/prompt_builder.py new file mode 100644 index 000000000..5c69394e1 --- /dev/null +++ b/src/neo4j_graphrag/experimental/components/rag/prompt_builder.py @@ -0,0 +1,33 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any + +from neo4j_graphrag.experimental.pipeline import Component, DataModel +from neo4j_graphrag.generation import PromptTemplate + +# class PromptData(DataModel): +# inputs: dict[str, Any] + + +class PromptResult(DataModel): + prompt: str + + +class PromptBuilder(Component): + def __init__(self, template: PromptTemplate): + self.template = template + + async def run(self, **kwargs: Any) -> PromptResult: + return PromptResult(prompt=self.template.format(**kwargs)) diff --git a/src/neo4j_graphrag/experimental/components/rag/retrievers.py b/src/neo4j_graphrag/experimental/components/rag/retrievers.py new file mode 100644 index 000000000..c83c263bc --- /dev/null +++ b/src/neo4j_graphrag/experimental/components/rag/retrievers.py @@ -0,0 +1,33 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any + +from neo4j_graphrag.experimental.pipeline import Component, DataModel +from neo4j_graphrag.retrievers.base import Retriever +from neo4j_graphrag.types import RetrieverResult + + +class RetrieverWrapperResult(DataModel): + result: RetrieverResult + + +class RetrieverWrapper(Component): + def __init__(self, retriever: Retriever): + self.retriever = retriever + + async def run(self, **kwargs: Any) -> RetrieverWrapperResult: + return RetrieverWrapperResult( + result=self.retriever.search(**kwargs), + ) diff --git a/src/neo4j_graphrag/experimental/pipeline/component.py b/src/neo4j_graphrag/experimental/pipeline/component.py index 84cd5bc0c..4e0fcd9d6 100644 --- a/src/neo4j_graphrag/experimental/pipeline/component.py +++ b/src/neo4j_graphrag/experimental/pipeline/component.py @@ -45,6 +45,9 @@ def __new__( for param in sig.parameters.values() if param.name not in ("self", "kwargs") } + attrs["anonymous_input_allowed"] = any( + param.name == "kwargs" for param in sig.parameters.values() + ) # extract returned fields from the run method return type hint return_model = get_type_hints(run_method).get("return") if return_model is None: @@ -76,6 +79,7 @@ class Component(abc.ABC, metaclass=ComponentMeta): # DO NOT CHANGE component_inputs: dict[str, dict[str, str | bool]] component_outputs: dict[str, dict[str, str | bool]] + anonymous_input_allowed: bool @abc.abstractmethod async def run(self, *args: Any, **kwargs: Any) -> DataModel: diff --git a/src/neo4j_graphrag/experimental/pipeline/config/runner.py b/src/neo4j_graphrag/experimental/pipeline/config/runner.py index c1bef9fee..e8408f0e5 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/runner.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/runner.py @@ -45,9 +45,13 @@ from neo4j_graphrag.experimental.pipeline.config.template_pipeline.simple_kg_builder import ( SimpleKGPipelineConfig, ) +from neo4j_graphrag.experimental.pipeline.config.template_pipeline.simple_rag_pipeline import ( + SimpleRAGPipelineConfig, +) from neo4j_graphrag.experimental.pipeline.config.types import PipelineType from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult from neo4j_graphrag.experimental.pipeline.types import PipelineDefinition +from neo4j_graphrag.generation.types import RagResultModel from neo4j_graphrag.utils.logging import prettify logger = logging.getLogger(__name__) @@ -68,6 +72,7 @@ class PipelineConfigWrapper(BaseModel): config: Union[ Annotated[PipelineConfig, Tag(PipelineType.NONE)], Annotated[SimpleKGPipelineConfig, Tag(PipelineType.SIMPLE_KG_PIPELINE)], + Annotated[SimpleRAGPipelineConfig, Tag(PipelineType.SIMPLE_RAG_PIPELINE)], ] = Field(discriminator=Discriminator(_get_discriminator_value)) def parse(self, resolved_data: dict[str, Any] | None = None) -> PipelineDefinition: @@ -136,3 +141,18 @@ async def close(self) -> None: logger.debug("PIPELINE_RUNNER: cleaning up (closing instantiated drivers...)") if self.config: await self.config.close() + + +class RagPipelineRunner(PipelineRunner): + async def search(self, **kwargs: Any) -> RagResultModel: + result = await self.run(kwargs) + context = None + if kwargs.get("return_context"): + context = await self.pipeline.store.get_result_for_component( + result.run_id, "retriever" + ) + context = context.get("result") + return RagResultModel( + answer=result.result["generator"]["content"], + retriever_result=context, + ) diff --git a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_rag_pipeline.py b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_rag_pipeline.py new file mode 100644 index 000000000..76a2e86d8 --- /dev/null +++ b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_rag_pipeline.py @@ -0,0 +1,116 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, ClassVar, Literal, Optional, Union, cast + +from pydantic import ConfigDict, RootModel + +from neo4j_graphrag.experimental.components.rag.generate import Generator +from neo4j_graphrag.experimental.components.rag.prompt_builder import PromptBuilder +from neo4j_graphrag.experimental.components.rag.retrievers import RetrieverWrapper +from neo4j_graphrag.experimental.pipeline.config.object_config import ObjectConfig +from neo4j_graphrag.experimental.pipeline.config.template_pipeline.base import ( + TemplatePipelineConfig, +) +from neo4j_graphrag.experimental.pipeline.config.types import PipelineType +from neo4j_graphrag.experimental.pipeline.types import ConnectionDefinition +from neo4j_graphrag.generation import RagTemplate +from neo4j_graphrag.retrievers.base import Retriever + + +class RetrieverConfig(ObjectConfig[RetrieverWrapper]): + INTERFACE = Retriever + # the result of _get_class is a Retriever + # it is translated into a RetrieverWrapper (which is a Component) + # in the 'parse' method below + DEFAULT_MODULE = "neo4j_graphrag.retrievers" + REQUIRED_PARAMS = ["driver"] + + def parse(self, resolved_data: Optional[dict[str, Any]] = None) -> RetrieverWrapper: + retriever = cast(Retriever, super().parse(resolved_data)) + return RetrieverWrapper(retriever) + + +class RetrieverType(RootModel): # type: ignore + root: Union[RetrieverWrapper, RetrieverConfig] + + model_config = ConfigDict(arbitrary_types_allowed=True) + + def parse(self, resolved_data: dict[str, Any] | None = None) -> RetrieverWrapper: + if isinstance(self.root, RetrieverWrapper): + return self.root + return self.root.parse(resolved_data) + + +class SimpleRAGPipelineConfig(TemplatePipelineConfig): + COMPONENTS: ClassVar[list[str]] = [ + "retriever", + "prompt_builder", + "generator", + ] + retriever: RetrieverType + prompt_template: Union[RagTemplate, str] = RagTemplate() + template_: Literal[PipelineType.SIMPLE_RAG_PIPELINE] = ( + PipelineType.SIMPLE_RAG_PIPELINE + ) + + model_config = ConfigDict(arbitrary_types_allowed=True) + + def _get_retriever(self) -> RetrieverWrapper: + retriever = self.retriever.parse(self._global_data) + return retriever + + def _get_prompt_builder(self) -> PromptBuilder: + if isinstance(self.prompt_template, str): + return PromptBuilder(RagTemplate(template=self.prompt_template)) + return PromptBuilder(self.prompt_template) + + def _get_generator(self) -> Generator: + llm = self.get_default_llm() + return Generator(llm) + + def _get_connections(self) -> list[ConnectionDefinition]: + connections = [ + ConnectionDefinition( + start="retriever", + end="prompt_builder", + input_config={"context": "retriever.result"}, + ), + ConnectionDefinition( + start="prompt_builder", + end="generator", + input_config={ + "prompt": "prompt_builder.prompt", + }, + ), + ] + return connections + + def get_run_params(self, user_input: dict[str, Any]) -> dict[str, Any]: + # query_text: str = "", + # examples: str = "", + # retriever_config: Optional[dict[str, Any]] = None, + # return_context: bool | None = None, + run_params = { + "retriever": { + "query_text": user_input["query_text"], + **user_input.get("retriever_config", {}), + }, + "prompt_builder": { + "query_text": user_input["query_text"], + "examples": user_input.get("examples", ""), + }, + "generator": {}, + } + return run_params diff --git a/src/neo4j_graphrag/experimental/pipeline/config/types.py b/src/neo4j_graphrag/experimental/pipeline/config/types.py index 48f91f485..d90a7d6af 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/types.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/types.py @@ -24,3 +24,4 @@ class PipelineType(str, enum.Enum): NONE = "none" SIMPLE_KG_PIPELINE = "SimpleKGPipeline" + SIMPLE_RAG_PIPELINE = "SimpleRAGPipeline" diff --git a/src/neo4j_graphrag/experimental/pipeline/pipeline.py b/src/neo4j_graphrag/experimental/pipeline/pipeline.py index 77526785f..56346dbee 100644 --- a/src/neo4j_graphrag/experimental/pipeline/pipeline.py +++ b/src/neo4j_graphrag/experimental/pipeline/pipeline.py @@ -363,7 +363,9 @@ def validate_parameter_mapping_for_task(self, task: TaskPipelineNode) -> bool: raise PipelineDefinitionError( f"Parameter '{param}' already mapped to {self.param_mapping[task.name][param]}" ) - if param not in task.component.component_inputs: + if param not in task.component.component_inputs and ( + not task.component.anonymous_input_allowed + ): raise PipelineDefinitionError( f"Parameter '{param}' is not a valid input for component '{task.name}' of type '{task.component.__class__.__name__}'" )