Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions api/routes/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,8 @@ def fill_form(form: FormFill, db: Session = Depends(get_db)):
pdf_form_path=fetched_template.pdf_path,
)

if not path:
raise AppError("PDF generation failed", status_code=400)

submission = FormSubmission(**form.model_dump(), output_pdf_path=path)
return create_form(db, submission)
211 changes: 134 additions & 77 deletions src/llm.py
Original file line number Diff line number Diff line change
@@ -1,109 +1,167 @@
import json
import logging
import os
import time
import requests
from requests.exceptions import Timeout, RequestException

# Module-level logger used by the LLM class below.
logger = logging.getLogger("fireform.llm")

# Configuration constants for the requests sent to the Ollama API.
# Per-request timeout, in seconds, passed as `timeout=` to requests.post().
LLM_REQUEST_TIMEOUT_SECONDS = 120
# Maximum number of attempts made for one field before giving up.
LLM_MAX_RETRIES = 3
# Base of the exponential backoff: the pause after failed attempt n is
# LLM_RETRY_BASE_DELAY_SECONDS * 2 ** (n - 1) seconds.
LLM_RETRY_BASE_DELAY_SECONDS = 2


class LLM:
# Store the transcript, the target fields and the result dictionary on the instance.
# NOTE(review): the `json` parameter shadows the imported `json` module inside this method.
def __init__(self, transcript_text=None, target_fields=None, json=None):
# A missing result dict is replaced by a fresh one per instance.
if json is None:
json = {}
self._transcript_text = transcript_text # must be a str to pass type_check_all()
self._target_fields = target_fields # list per type_check_all(); NOTE(review): main_loop() calls .keys() on it, so a dict seems to be expected there - confirm
self._json = json # dict that add_response_to_json() is called to fill
# NOTE(review): the three assignments below repeat the three above verbatim.
self._transcript_text = transcript_text
self._target_fields = target_fields
self._json = json

def type_check_all(self):
    """Validate the types of the transcript and the target fields.

    The transcript is checked first; the target fields are only checked
    when the transcript is valid.

    Raises:
        TypeError: if the transcript is not a string, or if the target
            fields are not a list.
    """
    # isinstance() instead of an exact type() comparison, so subclasses of
    # str and list are accepted as well.
    if not isinstance(self._transcript_text, str):
        raise TypeError(
            "ERROR in LLM() attributes -> "
            f"Transcript must be text. Input:\n\ttranscript_text: {self._transcript_text}"
        )
    if not isinstance(self._target_fields, list):
        raise TypeError(
            "ERROR in LLM() attributes -> "
            f"Target fields must be a list. Input:\n\ttarget_fields: {self._target_fields}"
        )

def build_prompt(self, current_field):
    """Create the extraction prompt for a single target field.

    Args:
        current_field: Name of the JSON field whose value the model is
            asked to find in the transcript.

    Returns:
        The prompt string: the fixed system instructions followed by the
        field name and the transcript text.
    """
    # Built from a list so every instruction line appears exactly once.
    # The result starts and ends with a newline, like a triple-quoted
    # literal whose content begins on the line after the opening quotes.
    prompt_lines = [
        "SYSTEM PROMPT:",
        "You are an AI assistant designed to help fillout json files with information extracted from transcribed voice recordings.",
        "You will receive the transcription, and the name of the JSON field whose value you have to identify in the context. Return",
        "only a single string containing the identified value for the JSON field.",
        'If the field name is plural, and you identify more than one possible value in the text, return both separated by a ";".',
        "If you don't identify the value in the provided text, return \"-1\".",
        "---",
        "DATA:",
        f"Target JSON field to find in text: {current_field}",
        "",
        f"TEXT: {self._transcript_text}",
    ]
    prompt = "\n" + "\n".join(prompt_lines) + "\n"

    return prompt

def main_loop(self):
timeout = 30
max_retries = 3
def _call_ollama(self, prompt, field_name):
    """Send a prompt to Ollama with timeout and retry logic.

    Timeouts, connection failures and 5xx responses are retried up to
    LLM_MAX_RETRIES times with exponential backoff. 4xx responses and
    malformed response bodies are not retried.

    Args:
        prompt: Full prompt text sent to the model.
        field_name: Name of the field being extracted; used only in log
            and error messages.

    Returns:
        The value stored under the "response" key of the JSON body.

    Raises:
        RuntimeError: on a 4xx response, on a body that is not JSON or
            has no "response" key, or when all attempts have failed.
    """
    ollama_host = os.getenv("OLLAMA_HOST", "http://localhost:11434").rstrip("/")
    ollama_url = f"{ollama_host}/api/generate"

    payload = {
        "model": "mistral",
        "prompt": prompt,
        "stream": False,
    }

    last_exception = None

    for attempt in range(1, LLM_MAX_RETRIES + 1):
        logger.info(
            "LLM request for field '%s' (attempt %d/%d)",
            field_name,
            attempt,
            LLM_MAX_RETRIES,
        )
        try:
            response = requests.post(
                ollama_url,
                json=payload,
                timeout=LLM_REQUEST_TIMEOUT_SECONDS,
            )
            response.raise_for_status()

        except requests.exceptions.Timeout as exc:
            last_exception = exc
            logger.warning(
                "LLM request timed out for field '%s' (attempt %d/%d)",
                field_name,
                attempt,
                LLM_MAX_RETRIES,
            )

        except requests.exceptions.ConnectionError as exc:
            last_exception = exc
            logger.warning(
                "Cannot connect to Ollama for field '%s' (attempt %d/%d)",
                field_name,
                attempt,
                LLM_MAX_RETRIES,
            )

        except requests.exceptions.HTTPError as exc:
            last_exception = exc
            # HTTPError can only come from raise_for_status() here, so
            # `response` is always bound at this point.
            if response.status_code < 500:
                # Client errors (4xx) should not be retried
                raise RuntimeError(
                    f"Ollama returned client error {response.status_code} "
                    f"for field '{field_name}': {exc}"
                ) from exc
            logger.warning(
                "Ollama server error %d for field '%s' (attempt %d/%d)",
                response.status_code,
                field_name,
                attempt,
                LLM_MAX_RETRIES,
            )

        else:
            # A successful status with an unusable body will not improve
            # on retry, so report it at once instead of leaking a raw
            # ValueError / KeyError / TypeError to the caller.
            try:
                result = response.json()["response"]
            except (ValueError, KeyError, TypeError) as exc:
                raise RuntimeError(
                    f"Ollama returned a malformed response for field "
                    f"'{field_name}': {exc!r}"
                ) from exc

            # Slicing already copes with short strings; log at most 100 chars.
            logger.info(
                "LLM response for field '%s': %s",
                field_name,
                result[:100],
            )
            return result

        # Exponential backoff before retry
        if attempt < LLM_MAX_RETRIES:
            delay = LLM_RETRY_BASE_DELAY_SECONDS * (2 ** (attempt - 1))
            logger.info("Retrying in %d seconds...", delay)
            time.sleep(delay)

    # All retries exhausted; chain the last failure so its traceback is kept.
    raise RuntimeError(
        f"LLM extraction failed for field '{field_name}' after "
        f"{LLM_MAX_RETRIES} attempts: {last_exception}"
    ) from last_exception

def main_loop(self):
    """Extract a value for every target field and store it in the result JSON.

    For each field a prompt is built, sent to the LLM through
    _call_ollama(), and the answer is passed to add_response_to_json().

    Returns:
        self, so calls can be chained.

    Raises:
        RuntimeError: propagated from _call_ollama() when a field cannot
            be extracted.
    """
    # Missing target fields mean "nothing to extract": the start-up log
    # already reports 0 fields in that case, so the loop must not crash.
    fields = self._target_fields or ()

    logger.info("Starting LLM extraction for %d fields", len(fields))

    # Iterating the container directly yields the keys of a dict and the
    # items of a list, so both kinds of target-field collections work.
    for field in fields:
        prompt = self.build_prompt(field)
        parsed_response = self._call_ollama(prompt, field_name=field)
        self.add_response_to_json(field, parsed_response)

    logger.info("LLM extraction complete. Result:\n%s", json.dumps(self._json, indent=2))

    return self

def add_response_to_json(self, field, value):
"""
this method adds the following value under the specified field,
or under a new field if the field doesn't exist, to the json dict
Adds the extracted value under the specified field in the JSON dict.
"""
value = value.strip().replace('"', "")
parsed_value = None
Expand All @@ -123,27 +181,26 @@ def add_response_to_json(self, field, value):

def handle_plural_values(self, plural_value):
    """Split a semicolon-separated string into a list of values.

    'value1; value2; value3' → ['value1', 'value2', 'value3']

    Args:
        plural_value: String holding several values separated by ';'.

    Returns:
        The individual values, in order, with surrounding whitespace removed.

    Raises:
        ValueError: if the string contains no ';' separator.
    """
    if ";" not in plural_value:
        raise ValueError(
            f"Value is not plural, doesn't have ; separator, Value: {plural_value}"
        )

    logger.debug("Formatting plural values for input: %s", plural_value)

    # strip() on both sides: lstrip() alone left trailing blanks in place,
    # e.g. "a ; b" produced "a " as its first value.
    values = [part.strip() for part in plural_value.split(";")]

    logger.debug("Resulting formatted list: %s", values)

    return values

def get_data(self):
    """Return the dictionary that holds the extracted field values."""
    return self._json