Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion api/routes/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,4 @@ def fill_form(form: FormFill, db: Session = Depends(get_db)):
)

submission = FormSubmission(**form.model_dump(), output_pdf_path=path)
return create_form(db, submission)
return create_form(db, submission)
8 changes: 7 additions & 1 deletion api/schemas/forms.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
from pydantic import BaseModel
from pydantic import BaseModel, field_validator

class FormFill(BaseModel):
    """Request body accepted by the form-fill route.

    Validation rejects an ``input_text`` that is empty or made up only of
    whitespace; any other value is passed through unchanged (it is not
    stripped).
    """

    # presumably the id of the template to fill — confirm against the route/DB layer
    template_id: int
    # Free text the form values are extracted from; must not be blank.
    input_text: str

    @field_validator("input_text")
    @classmethod
    def validate_input_text(cls, value: str) -> str:
        """Ensure ``input_text`` contains at least one non-whitespace character.

        Raises:
            ValueError: if the text is empty or whitespace-only. Pydantic
                reports this to the caller as a validation error.
        """
        # An empty string strips to "", so this single check covers both
        # the empty and the whitespace-only case.
        if not value.strip():
            raise ValueError("Input text cannot be empty")
        return value


class FormFillResponse(BaseModel):
id: int
Expand Down
62 changes: 62 additions & 0 deletions api/services/prompt_builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
def build_extraction_prompt(input_text: str) -> str:
    """Build the LLM prompt that asks for structured incident-report fields.

    The prompt lists the extraction rules, the five fields to extract
    (name, location, date, incident_type, description), the required JSON
    output format, one positive example and one negative example, and then
    appends ``input_text`` verbatim as the final section.

    The worked examples are kept consistent with the "do NOT infer" rule:
    every value in an example output appears in that example's input, and the
    negative example carries its own input so that the empty ``date`` in the
    corrected output is justified.

    Args:
        input_text: The raw report text to extract from. It is interpolated
            as-is (no escaping or trimming), so any braces it contains are
            emitted literally.

    Returns:
        The complete prompt string. Doubled braces in the template below are
        f-string escapes and appear as single ``{`` / ``}`` in the result.
    """
    return f"""
You are an AI system that extracts structured information from incident reports.
Your task is to extract ONLY information explicitly present in the input text.

STRICT RULES:
- Do NOT infer or guess missing information
- If a field is not clearly mentioned, return an empty string ""
- Do NOT add any extra fields beyond those specified
- Do NOT modify or reinterpret values

Extract the following fields:
- name
- location
- date (YYYY-MM-DD if possible)
- incident_type
- description

Return ONLY valid JSON. Do not include any extra text, explanation, or formatting outside JSON.
The output MUST be a valid JSON object and parsable by json.loads().
Format:
{{
"name": "",
"location": "",
"date": "",
"incident_type": "",
"description": ""
}}

Example:

Input:
Fire reported near Central Park on Jan 5, 2024 involving a vehicle.

Output:
{{
"name": "",
"location": "Central Park",
"date": "2024-01-05",
"incident_type": "fire",
"description": "Fire involving a vehicle"
}}

Negative Example (DO NOT DO THIS):

Input:
Fire reported near Central Park.

Incorrect Output:
(This output is incorrect because it includes inferred/assumed values)
{{
"location": "Central Park (assumed)",
"date": "2024-01-05"
}}

Correct Output:
{{
"location": "Central Park",
"date": ""
}}

Now extract strictly from the following input (follow all rules above):

{input_text}
"""
63 changes: 24 additions & 39 deletions src/llm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import os
import requests
from api.services.prompt_builder import build_extraction_prompt
from requests.exceptions import Timeout, RequestException


Expand All @@ -24,47 +25,35 @@ def type_check_all(self):
Target fields must be a list. Input:\n\ttarget_fields: {self._target_fields}"
)

def build_prompt(self, current_field):
    """
    Assemble the prompt used to extract a single target field from the transcript.
    @params: current_field -> name of the JSON field whose value the model is asked to find.
    @returns: the prompt string, with the field name and self._transcript_text filled in.
    """
    transcript = self._transcript_text
    return f"""
SYSTEM PROMPT:
You are an AI assistant designed to help fillout json files with information extracted from transcribed voice recordings.
You will receive the transcription, and the name of the JSON field whose value you have to identify in the context. Return
only a single string containing the identified value for the JSON field.
If the field name is plural, and you identify more than one possible value in the text, return both separated by a ";".
If you don't identify the value in the provided text, return "-1".
---
DATA:
Target JSON field to find in text: {current_field}

TEXT: {transcript}
"""

def main_loop(self):
timeout = 30
max_retries = 3

# self.type_check_all()
total_fields = len(self._target_fields)
for i, field in enumerate(self._target_fields.keys(), 1):
prompt = self.build_prompt(field)
# print(prompt)
# ollama_url = "http://localhost:11434/api/generate"

for i, field in enumerate(self._target_fields, 1):
ollama_host = os.getenv("OLLAMA_HOST", "http://localhost:11434").rstrip("/")
ollama_url = f"{ollama_host}/api/generate"

base_prompt = build_extraction_prompt(self._transcript_text)

prompt = f"""
{base_prompt}

Focus specifically on extracting the value for this field:
{field}

Return only the extracted value as a plain string. Do not return JSON.
"""

payload = {
"model": "mistral",
"prompt": prompt,
"stream": False, # don't really know why --> look into this later.
"stream": False,
}

json_data = None

try:
for attempt in range(max_retries):
try:
Expand All @@ -76,22 +65,21 @@ def main_loop(self):
print(f"Ollama request timed out (attempt {attempt+1})")
except RequestException as e:
print(f"Ollama request failed: {e}")

except requests.exceptions.ConnectionError:
raise ConnectionError(
f"Could not connect to Ollama at {ollama_url}. "
"Please ensure Ollama is running and accessible."
f"Could not connect to Ollama at {ollama_url}."
)
except requests.exceptions.HTTPError as e:
raise RuntimeError(f"Ollama returned an error: {e}")

if json_data is None:
raise RuntimeError("Failed to get response from Ollama after retries.")
else:
# parse response
parsed_response = json_data["response"]
# print(parsed_response)
self.add_response_to_json(field, parsed_response)
print(f"[{i}/{total_fields}] Extracted data for field '{field}' successfully.")

parsed_response = json_data["response"]
self.add_response_to_json(field, parsed_response)

print(f"[{i}/{total_fields}] Extracted data for field '{field}' successfully.")

print("----------------------------------")
print("\t[LOG] Resulting JSON created from the input text:")
Expand Down Expand Up @@ -124,8 +112,6 @@ def add_response_to_json(self, field, value):
def handle_plural_values(self, plural_value):
"""
This method handles plural values.
Takes in strings of the form 'value1; value2; value3; ...; valueN'
returns a list with the respective values -> [value1, value2, value3, ..., valueN]
"""
if ";" not in plural_value:
raise ValueError(
Expand All @@ -137,7 +123,6 @@ def handle_plural_values(self, plural_value):
)
values = plural_value.split(";")

# Remove trailing leading whitespace
for i in range(len(values)):
values[i] = values[i].lstrip()

Expand All @@ -146,4 +131,4 @@ def handle_plural_values(self, plural_value):
return values

def get_data(self):
    """Return the object currently stored in ``self._json`` (not a copy)."""
    return self._json