Skip to content

Commit 7df4f2f

Browse files
authored
Fix the copilot detector class inheritance (#903)
The class to inherit from is BaseClientDetector, the HeaderDetector is supposed to be composed of. Also add tests.
1 parent f60a1a0 commit 7df4f2f

File tree

2 files changed

+24
-8
lines changed

2 files changed

+24
-8
lines changed

src/codegate/clients/detector.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -160,13 +160,14 @@ def client_name(self) -> ClientType:
160160
return ClientType.OPEN_INTERPRETER
161161

162162

163-
class CopilotDetector(HeaderDetector):
163+
class CopilotDetector(BaseClientDetector):
164164
"""
165165
Detector for Copilot client based on user agent
166166
"""
167167

168168
def __init__(self):
169-
super().__init__("user-agent", "Copilot")
169+
super().__init__()
170+
self.user_agent_detector = UserAgentDetector("Copilot")
170171

171172
@property
172173
def client_name(self) -> ClientType:

tests/clients/test_detector.py

+21-6
Original file line numberDiff line numberDiff line change
@@ -272,20 +272,23 @@ async def get_json():
272272

273273

274274
class TestCopilotDetector:
275-
def test_successful_detection(self, mock_request):
275+
@pytest.mark.asyncio
276+
async def test_successful_detection(self, mock_request):
276277
detector = CopilotDetector()
277278
mock_request.headers = Headers({"user-agent": "Copilot"})
278-
assert detector.detect(mock_request) is True
279+
assert await detector.detect(mock_request) is True
279280
assert detector.client_name == ClientType.COPILOT
280281

281-
def test_failed_detection(self, mock_request):
282+
@pytest.mark.asyncio
283+
async def test_failed_detection(self, mock_request):
282284
detector = CopilotDetector()
283285
mock_request.headers = Headers({"user-agent": "Different Client"})
284-
assert detector.detect(mock_request) is False
286+
assert await detector.detect(mock_request) is False
285287

286-
def test_missing_user_agent(self, mock_request):
288+
@pytest.mark.asyncio
289+
async def test_missing_user_agent(self, mock_request):
287290
detector = CopilotDetector()
288-
assert detector.detect(mock_request) is False
291+
assert await detector.detect(mock_request) is False
289292

290293

291294
class TestDetectClient:
@@ -353,3 +356,15 @@ async def test_endpoint(request: Request):
353356

354357
result = await test_endpoint(mock_request)
355358
assert result == ClientType.KODU
359+
360+
@pytest.mark.asyncio
361+
async def test_copilot_detection_in_detect_client(self, mock_request):
362+
detect_client = DetectClient()
363+
mock_request.headers = Headers({"user-agent": "Copilot"})
364+
365+
@detect_client
366+
async def test_endpoint(request: Request):
367+
return request.state.detected_client
368+
369+
result = await test_endpoint(mock_request)
370+
assert result == ClientType.COPILOT

0 commit comments

Comments
 (0)