refactor: Refactor prompt node (#4580)
* Refactor prompt structure

* Refactor prompt tests structure

* Fix pylint

* Move TestPromptTemplateSyntax to test_prompt_template.py
silvanocerza authored Apr 3, 2023
1 parent c202866 commit 1cc4c9c
Showing 8 changed files with 996 additions and 953 deletions.
5 changes: 4 additions & 1 deletion haystack/nodes/prompt/__init__.py
@@ -1,2 +1,5 @@
-from haystack.nodes.prompt.prompt_node import PromptNode, PromptTemplate, PromptModel, BaseOutputParser, AnswerParser
+from haystack.nodes.prompt.prompt_node import PromptNode
+from haystack.nodes.prompt.prompt_template import PromptTemplate
+from haystack.nodes.prompt.prompt_model import PromptModel
+from haystack.nodes.prompt.shapers import BaseOutputParser, AnswerParser
 from haystack.nodes.prompt.providers import PromptModelInvocationLayer
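
Since the package `__init__.py` keeps re-exporting every public name, the refactor changes only the internal module layout and existing user code keeps working. A minimal sketch of the two equivalent import styles, based on the re-exports above:

# Public API is unchanged: the package __init__ re-exports every name.
from haystack.nodes.prompt import PromptNode, PromptModel, AnswerParser

# The new internal module paths can also be imported directly:
from haystack.nodes.prompt.prompt_template import PromptTemplate
from haystack.nodes.prompt.shapers import BaseOutputParser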
155 changes: 155 additions & 0 deletions haystack/nodes/prompt/prompt_model.py
@@ -0,0 +1,155 @@
from typing import Dict, List, Optional, Tuple, Union, Any, Type, overload
import logging

import torch

from haystack.nodes.base import BaseComponent
from haystack.nodes.prompt.providers import PromptModelInvocationLayer, instruction_following_models
from haystack.schema import Document, MultiLabel

logger = logging.getLogger(__name__)


class PromptModel(BaseComponent):
"""
The PromptModel class is a component that uses a pre-trained model to perform tasks defined in a prompt. Out of
the box, it supports model invocation layers for:
- Hugging Face transformers (all text2text-generation and text-generation models)
- OpenAI InstructGPT models
- Azure OpenAI InstructGPT models
Although it's possible to use PromptModel to make prompt invocations on the underlying model, use
PromptNode to interact with the model. PromptModel instances are a way for multiple
PromptNode instances to use a single PromptNode, and thus save computational resources.
For more details, refer to [PromptModels](https://docs.haystack.deepset.ai/docs/prompt_node#models).
"""

outgoing_edges = 1

def __init__(
self,
model_name_or_path: str = "google/flan-t5-base",
max_length: Optional[int] = 100,
api_key: Optional[str] = None,
use_auth_token: Optional[Union[str, bool]] = None,
use_gpu: Optional[bool] = None,
devices: Optional[List[Union[str, torch.device]]] = None,
invocation_layer_class: Optional[Type[PromptModelInvocationLayer]] = None,
model_kwargs: Optional[Dict] = None,
):
"""
Creates an instance of PromptModel.
:param model_name_or_path: The name or path of the underlying model.
:param max_length: The maximum length of the output text generated by the model.
:param api_key: The API key to use for the model.
:param use_auth_token: The Hugging Face token to use.
:param use_gpu: Whether to use GPU or not.
:param devices: The devices to use where the model is loaded.
:param invocation_layer_class: The custom invocation layer class to use. If None, known invocation layers are used.
:param model_kwargs: Additional keyword arguments passed to the underlying model.
Note that Azure OpenAI InstructGPT models require two additional parameters: azure_base_url (The URL for the
Azure OpenAI API endpoint, usually in the form `https://<your-endpoint>.openai.azure.com') and
azure_deployment_name (the name of the Azure OpenAI API deployment). You should add these parameters
in the `model_kwargs` dictionary.
"""
super().__init__()
self.model_name_or_path = model_name_or_path
self.max_length = max_length
self.api_key = api_key
self.use_auth_token = use_auth_token
self.use_gpu = use_gpu
self.devices = devices

self.model_kwargs = model_kwargs if model_kwargs else {}
self.model_invocation_layer = self.create_invocation_layer(invocation_layer_class=invocation_layer_class)
is_instruction_following: bool = any(m in model_name_or_path for m in instruction_following_models())
if not is_instruction_following:
logger.warning(
"PromptNode has been potentially initialized with a language model not fine-tuned on instruction-following tasks. "
"Many of the default prompts and PromptTemplates may not work as intended. "
"Use custom prompts and PromptTemplates specific to the %s model",
model_name_or_path,
)

def create_invocation_layer(
self, invocation_layer_class: Optional[Type[PromptModelInvocationLayer]]
) -> PromptModelInvocationLayer:
kwargs = {
"api_key": self.api_key,
"use_auth_token": self.use_auth_token,
"use_gpu": self.use_gpu,
"devices": self.devices,
}
all_kwargs = {**self.model_kwargs, **kwargs}

if invocation_layer_class:
return invocation_layer_class(
model_name_or_path=self.model_name_or_path, max_length=self.max_length, **all_kwargs
)
# search all invocation layer classes and find the first one that supports the model,
# then create an instance of that invocation layer
for invocation_layer in PromptModelInvocationLayer.invocation_layer_providers:
if invocation_layer.supports(self.model_name_or_path, **all_kwargs):
return invocation_layer(
model_name_or_path=self.model_name_or_path, max_length=self.max_length, **all_kwargs
)
raise ValueError(
f"Model {self.model_name_or_path} is not supported - no matching invocation layer found."
f" Currently supported invocation layers are: {PromptModelInvocationLayer.invocation_layer_providers}"
f" You can implement and provide custom invocation layer for {self.model_name_or_path} by subclassing "
"PromptModelInvocationLayer."
)

def invoke(self, prompt: Union[str, List[str], List[Dict[str, str]]], **kwargs) -> List[str]:
"""
Takes in a prompt and returns a list of responses using the underlying invocation layer.
:param prompt: The prompt to use for the invocation. It can be a single prompt or a list of prompts.
:param kwargs: Additional keyword arguments to pass to the invocation layer.
:return: A list of model-generated responses for the prompt or prompts.
"""
output = self.model_invocation_layer.invoke(prompt=prompt, **kwargs)
return output

@overload
def _ensure_token_limit(self, prompt: str) -> str:
...

@overload
def _ensure_token_limit(self, prompt: List[Dict[str, str]]) -> List[Dict[str, str]]:
...

def _ensure_token_limit(self, prompt: Union[str, List[Dict[str, str]]]) -> Union[str, List[Dict[str, str]]]:
"""Ensure that length of the prompt and answer is within the maximum token length of the PromptModel.
:param prompt: Prompt text to be sent to the generative model.
"""
return self.model_invocation_layer._ensure_token_limit(prompt=prompt)

def run(
self,
query: Optional[str] = None,
file_paths: Optional[List[str]] = None,
labels: Optional[MultiLabel] = None,
documents: Optional[List[Document]] = None,
meta: Optional[dict] = None,
) -> Tuple[Dict, str]:
raise NotImplementedError("This method should never be implemented in the derived class")

def run_batch(
self,
queries: Optional[Union[str, List[str]]] = None,
file_paths: Optional[List[str]] = None,
labels: Optional[Union[MultiLabel, List[MultiLabel]]] = None,
documents: Optional[Union[List[Document], List[List[Document]]]] = None,
meta: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
params: Optional[dict] = None,
debug: Optional[bool] = None,
):
raise NotImplementedError("This method should never be implemented in the derived class")

def __repr__(self):
return "{}({!r})".format(self.__class__.__name__, self.__dict__)