diff --git a/api/routes/forms.py b/api/routes/forms.py index cee5356..1a56687 100644 --- a/api/routes/forms.py +++ b/api/routes/forms.py @@ -24,4 +24,4 @@ def fill_form(form: FormFill, db: Session = Depends(get_db)): ) submission = FormSubmission(**form.model_dump(), output_pdf_path=path) - return create_form(db, submission) + return create_form(db, submission) \ No newline at end of file diff --git a/api/schemas/forms.py b/api/schemas/forms.py index 3cce650..bf6957e 100644 --- a/api/schemas/forms.py +++ b/api/schemas/forms.py @@ -1,9 +1,15 @@ -from pydantic import BaseModel +from pydantic import BaseModel, field_validator class FormFill(BaseModel): template_id: int input_text: str + @field_validator("input_text") + def validate_input_text(cls, value): + if not value or not value.strip(): + raise ValueError("Input text cannot be empty") + return value + class FormFillResponse(BaseModel): id: int diff --git a/api/services/prompt_builder.py b/api/services/prompt_builder.py new file mode 100644 index 0000000..843c7cd --- /dev/null +++ b/api/services/prompt_builder.py @@ -0,0 +1,62 @@ +def build_extraction_prompt(input_text: str) -> str: + return f""" +You are an AI system that extracts structured information from incident reports. +Your task is to extract ONLY information explicitly present in the input text. + +STRICT RULES: +- Do NOT infer or guess missing information +- If a field is not clearly mentioned, return an empty string "" +- Do NOT add any extra fields beyond those specified +- Do NOT modify or reinterpret values + +Extract the following fields: +- name +- location +- date (YYYY-MM-DD if possible) +- incident_type +- description + +Return ONLY valid JSON. Do not include any extra text, explanation, or formatting outside JSON. +The output MUST be a valid JSON object and parsable by json.loads(). +Format: +{{ + "name": "", + "location": "", + "date": "", + "incident_type": "", + "description": "" +}} + +Example: + +Input: +Fire reported near Central Park on Jan 5 involving a vehicle. + +Output: +{{ + "name": "", + "location": "Central Park", + "date": "2024-01-05", + "incident_type": "fire", + "description": "Fire involving a vehicle" +}} + +Negative Example (DO NOT DO THIS): + +Incorrect Output: +(This output is incorrect because it includes inferred/assumed values) +{{ + "location": "Central Park (assumed)", + "date": "2024-01-05" +}} + +Correct Output: +{{ + "location": "Central Park", + "date": "" +}} + +Now extract strictly from the following input (follow all rules above): + +{input_text} +""" \ No newline at end of file diff --git a/src/llm.py b/src/llm.py index 3621187..a3f3099 100644 --- a/src/llm.py +++ b/src/llm.py @@ -1,6 +1,7 @@ import json import os import requests +from api.services.prompt_builder import build_extraction_prompt from requests.exceptions import Timeout, RequestException @@ -24,47 +25,35 @@ def type_check_all(self): Target fields must be a list. Input:\n\ttarget_fields: {self._target_fields}" ) - def build_prompt(self, current_field): - """ - This method is in charge of the prompt engineering. It creates a specific prompt for each target field. - @params: current_field -> represents the current element of the json that is being prompted. - """ - prompt = f""" - SYSTEM PROMPT: - You are an AI assistant designed to help fillout json files with information extracted from transcribed voice recordings. - You will receive the transcription, and the name of the JSON field whose value you have to identify in the context. Return - only a single string containing the identified value for the JSON field. - If the field name is plural, and you identify more than one possible value in the text, return both separated by a ";". - If you don't identify the value in the provided text, return "-1". - --- - DATA: - Target JSON field to find in text: {current_field} - - TEXT: {self._transcript_text} - """ - - return prompt - def main_loop(self): timeout = 30 max_retries = 3 - # self.type_check_all() total_fields = len(self._target_fields) - for i, field in enumerate(self._target_fields.keys(), 1): - prompt = self.build_prompt(field) - # print(prompt) - # ollama_url = "http://localhost:11434/api/generate" + + for i, field in enumerate(self._target_fields, 1): ollama_host = os.getenv("OLLAMA_HOST", "http://localhost:11434").rstrip("/") ollama_url = f"{ollama_host}/api/generate" + base_prompt = build_extraction_prompt(self._transcript_text) + + prompt = f""" +{base_prompt} + +Focus specifically on extracting the value for this field: +{field} + +Return only the extracted value as a plain string. Do not return JSON. +""" + payload = { "model": "mistral", "prompt": prompt, - "stream": False, # don't really know why --> look into this later. + "stream": False, } json_data = None + try: for attempt in range(max_retries): try: @@ -76,22 +65,21 @@ def main_loop(self): print(f"Ollama request timed out (attempt {attempt+1})") except RequestException as e: print(f"Ollama request failed: {e}") + except requests.exceptions.ConnectionError: raise ConnectionError( - f"Could not connect to Ollama at {ollama_url}. " - "Please ensure Ollama is running and accessible." + f"Could not connect to Ollama at {ollama_url}." ) except requests.exceptions.HTTPError as e: raise RuntimeError(f"Ollama returned an error: {e}") if json_data is None: raise RuntimeError("Failed to get response from Ollama after retries.") - else: - # parse response - parsed_response = json_data["response"] - # print(parsed_response) - self.add_response_to_json(field, parsed_response) - print(f"[{i}/{total_fields}] Extracted data for field '{field}' successfully.") + + parsed_response = json_data["response"] + self.add_response_to_json(field, parsed_response) + + print(f"[{i}/{total_fields}] Extracted data for field '{field}' successfully.") print("----------------------------------") print("\t[LOG] Resulting JSON created from the input text:") @@ -124,8 +112,6 @@ def add_response_to_json(self, field, value): def handle_plural_values(self, plural_value): """ This method handles plural values. - Takes in strings of the form 'value1; value2; value3; ...; valueN' - returns a list with the respective values -> [value1, value2, value3, ..., valueN] """ if ";" not in plural_value: raise ValueError( @@ -137,7 +123,6 @@ def handle_plural_values(self, plural_value): ) values = plural_value.split(";") - # Remove trailing leading whitespace for i in range(len(values)): values[i] = values[i].lstrip() @@ -146,4 +131,4 @@ def handle_plural_values(self, plural_value): return values def get_data(self): - return self._json + return self._json \ No newline at end of file