Skip to content

Commit 84a701f

Browse files
Pouyanpitgasser-nv
authored andcommitted
fix(jailbreak): handle URL joining with/without trailing slashes (#1346)
1 parent 30f9cac commit 84a701f

File tree

2 files changed

+48
-12
lines changed

2 files changed

+48
-12
lines changed

nemoguardrails/library/jailbreak_detection/request.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,30 @@
3131
import asyncio
3232
import logging
3333
from typing import Optional
34+
from urllib.parse import urljoin
3435

3536
import aiohttp
3637

3738
log = logging.getLogger(__name__)
3839

3940

41+
def join_nim_url(base_url: str, classification_path: str) -> str:
42+
"""Join NIM base URL with classification path, handling trailing/leading slashes.
43+
44+
Args:
45+
base_url: The base NIM URL (with or without trailing slash)
46+
classification_path: The classification endpoint path (with or without leading slash)
47+
48+
Returns:
49+
Properly joined URL
50+
"""
51+
# Ensure base_url ends with '/' for proper urljoin behavior
52+
normalized_base = base_url.rstrip("/") + "/"
53+
# Remove leading slash from classification path to ensure relative joining
54+
normalized_path = classification_path.lstrip("/")
55+
return urljoin(normalized_base, normalized_path)
56+
57+
4058
async def jailbreak_detection_heuristics_request(
4159
prompt: str,
4260
api_url: str = "http://localhost:1337/heuristics",
@@ -101,14 +119,12 @@ async def jailbreak_nim_request(
101119
nim_auth_token: Optional[str],
102120
nim_classification_path: str,
103121
):
104-
from urllib.parse import urljoin
105-
106122
headers = {"Content-Type": "application/json", "Accept": "application/json"}
107123
payload = {
108124
"input": prompt,
109125
}
110126

111-
endpoint = urljoin(nim_url, nim_classification_path)
127+
endpoint = join_nim_url(nim_url, nim_classification_path)
112128
try:
113129
async with aiohttp.ClientSession() as session:
114130
try:

tests/test_jailbreak_request.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,31 +25,51 @@ class TestJailbreakRequestChanges:
2525
"""Test jailbreak request function changes introduced in this PR."""
2626

2727
def test_url_joining_logic(self):
28-
"""Test that URL joining works correctly using urljoin."""
28+
"""Test that URL joining works correctly with all slash combinations."""
29+
from nemoguardrails.library.jailbreak_detection.request import join_nim_url
30+
2931
test_cases = [
3032
(
3133
"http://localhost:8000/v1",
3234
"classify",
33-
"http://localhost:8000/classify",
34-
), # v1 replaced by classify
35+
"http://localhost:8000/v1/classify",
36+
),
3537
(
3638
"http://localhost:8000/v1/",
3739
"classify",
3840
"http://localhost:8000/v1/classify",
39-
), # trailing slash preserves v1
41+
),
42+
(
43+
"http://localhost:8000/v1",
44+
"/classify",
45+
"http://localhost:8000/v1/classify",
46+
),
4047
(
41-
"http://localhost:8000",
42-
"v1/classify",
48+
"http://localhost:8000/v1/",
49+
"/classify",
4350
"http://localhost:8000/v1/classify",
4451
),
52+
("http://localhost:8000", "classify", "http://localhost:8000/classify"),
53+
("http://localhost:8000", "/classify", "http://localhost:8000/classify"),
54+
("http://localhost:8000/", "classify", "http://localhost:8000/classify"),
4555
("http://localhost:8000/", "/classify", "http://localhost:8000/classify"),
56+
(
57+
"http://localhost:8000/api/v1",
58+
"classify",
59+
"http://localhost:8000/api/v1/classify",
60+
),
61+
(
62+
"http://localhost:8000/api/v1/",
63+
"/classify",
64+
"http://localhost:8000/api/v1/classify",
65+
),
4666
]
4767

48-
for base_url, path, expected_url in test_cases:
49-
result = urljoin(base_url, path)
68+
for base_url, classification_path, expected_url in test_cases:
69+
result = join_nim_url(base_url, classification_path)
5070
assert (
5171
result == expected_url
52-
), f"urljoin({base_url}, {path}) should equal {expected_url}"
72+
), f"join_nim_url({base_url}, {classification_path}) should equal {expected_url}, got {result}"
5373

5474
def test_auth_header_logic(self):
5575
"""Test the authorization header logic."""

0 commit comments

Comments
 (0)