Skip to content

Commit 43a7d6b

Browse files
rakduttaNAYANAR0502
authored andcommitted
Fix: issue Non-standard redirect handling in _validate_gateway_url for STREAMABLEHTTP transport (#1425)
* debug Signed-off-by: rakdutta <[email protected]> * redirect -steamblehttp Signed-off-by: rakdutta <[email protected]> * remove addtional line Signed-off-by: rakdutta <[email protected]> * test Signed-off-by: rakdutta <[email protected]> * validate gateway Signed-off-by: rakdutta <[email protected]> * ruff Signed-off-by: rakdutta <[email protected]> * add doctring and doctest in observability.py Signed-off-by: rakdutta <[email protected]> * ruff Signed-off-by: rakdutta <[email protected]> * flake8 Signed-off-by: rakdutta <[email protected]> --------- Signed-off-by: rakdutta <[email protected]>
1 parent 2e6d516 commit 43a7d6b

File tree

5 files changed

+345
-74
lines changed

5 files changed

+345
-74
lines changed

mcpgateway/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -914,6 +914,7 @@ def parse_issuers(cls, v: Any) -> list[str]:
914914

915915
# Validation Gateway URL
916916
gateway_validation_timeout: int = 5 # seconds
917+
gateway_max_redirects: int = 5
917918

918919
filelock_name: str = "gateway_service_leader.lock"
919920

mcpgateway/routers/observability.py

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,27 @@ def list_traces(
8080
8181
Returns:
8282
List[ObservabilityTraceRead]: List of traces matching filters
83+
84+
Examples:
85+
>>> import mcpgateway.routers.observability as obs
86+
>>> class FakeTrace:
87+
... def __init__(self, trace_id='t1'):
88+
... self.trace_id = trace_id
89+
... self.name = 'n'
90+
... self.start_time = None
91+
... self.end_time = None
92+
... self.duration_ms = 100
93+
... self.status = 'ok'
94+
... self.http_method = 'GET'
95+
... self.http_url = '/'
96+
... self.http_status_code = 200
97+
... self.user_email = 'u'
98+
>>> class FakeService:
99+
... def query_traces(self, **kwargs):
100+
... return [FakeTrace('t1')]
101+
>>> obs.ObservabilityService = FakeService
102+
>>> obs.list_traces(db=None)[0].trace_id
103+
't1'
83104
"""
84105
service = ObservabilityService()
85106
traces = service.query_traces(
@@ -138,6 +159,27 @@ def query_traces_advanced(
138159
139160
Raises:
140161
HTTPException: 400 error if request body is invalid
162+
163+
Examples:
164+
>>> from fastapi import HTTPException
165+
>>> try:
166+
... query_traces_advanced({"start_time": "not-a-date"}, db=None)
167+
... except HTTPException as e:
168+
... (e.status_code, "Invalid request body" in str(e.detail))
169+
(400, True)
170+
171+
>>> import mcpgateway.routers.observability as obs
172+
>>> class FakeTrace:
173+
... def __init__(self):
174+
... self.trace_id = 'tx'
175+
... self.name = 'n'
176+
177+
>>> class FakeService2:
178+
... def query_traces(self, **kwargs):
179+
... return [FakeTrace()]
180+
>>> obs.ObservabilityService = FakeService2
181+
>>> obs.query_traces_advanced({}, db=None)[0].trace_id
182+
'tx'
141183
"""
142184
# Third-Party
143185
from pydantic import ValidationError
@@ -199,6 +241,24 @@ def get_trace(trace_id: str, db: Session = Depends(get_db)):
199241
200242
Raises:
201243
HTTPException: 404 if trace not found
244+
245+
Examples:
246+
>>> import mcpgateway.routers.observability as obs
247+
>>> class FakeService:
248+
... def get_trace_with_spans(self, db, trace_id):
249+
... return None
250+
>>> obs.ObservabilityService = FakeService
251+
>>> try:
252+
... obs.get_trace('missing', db=None)
253+
... except obs.HTTPException as e:
254+
... e.status_code
255+
404
256+
>>> class FakeService2:
257+
... def get_trace_with_spans(self, db, trace_id):
258+
... return {'trace_id': trace_id}
259+
>>> obs.ObservabilityService = FakeService2
260+
>>> obs.get_trace('found', db=None)['trace_id']
261+
'found'
202262
"""
203263
service = ObservabilityService()
204264
trace = service.get_trace_with_spans(db, trace_id)
@@ -235,6 +295,20 @@ def list_spans(
235295
236296
Returns:
237297
List[ObservabilitySpanRead]: List of spans matching filters
298+
299+
Examples:
300+
>>> import mcpgateway.routers.observability as obs
301+
>>> class FakeSpan:
302+
... def __init__(self):
303+
... self.span_id = 's1'
304+
... self.trace_id = 't1'
305+
... self.name = 'op'
306+
>>> class FakeService:
307+
... def query_spans(self, **kwargs):
308+
... return [FakeSpan()]
309+
>>> obs.ObservabilityService = FakeService
310+
>>> obs.list_spans(db=None)[0].span_id
311+
's1'
238312
"""
239313
service = ObservabilityService()
240314
spans = service.query_spans(
@@ -266,6 +340,16 @@ def cleanup_old_traces(
266340
267341
Returns:
268342
dict: Number of deleted traces and cutoff time
343+
344+
Examples:
345+
>>> import mcpgateway.routers.observability as obs
346+
>>> class FakeService:
347+
... def delete_old_traces(self, db, cutoff):
348+
... return 5
349+
>>> obs.ObservabilityService = FakeService
350+
>>> res = obs.cleanup_old_traces(days=7, db=None)
351+
>>> res['deleted']
352+
5
269353
"""
270354
service = ObservabilityService()
271355
cutoff_time = datetime.now() - timedelta(days=days)
@@ -358,6 +442,41 @@ def export_traces(
358442
359443
Raises:
360444
HTTPException: 400 error if format is invalid or export fails
445+
446+
Examples:
447+
>>> from fastapi import HTTPException
448+
>>> try:
449+
... export_traces({}, format="xml", db=None)
450+
... except HTTPException as e:
451+
... (e.status_code, "format must be one of" in str(e.detail))
452+
(400, True)
453+
>>> import mcpgateway.routers.observability as obs
454+
>>> from datetime import datetime
455+
>>> class FakeTrace:
456+
... def __init__(self):
457+
... self.trace_id = 'tx'
458+
... self.name = 'name'
459+
... self.start_time = datetime(2025,1,1)
460+
... self.end_time = None
461+
... self.duration_ms = 100
462+
... self.status = 'ok'
463+
... self.http_method = 'GET'
464+
... self.http_url = '/'
465+
... self.http_status_code = 200
466+
... self.user_email = 'u'
467+
>>> class FakeService:
468+
... def query_traces(self, **kwargs):
469+
... return [FakeTrace()]
470+
>>> obs.ObservabilityService = FakeService
471+
>>> out = obs.export_traces({}, format='json', db=None)
472+
>>> out[0]['trace_id']
473+
'tx'
474+
>>> resp = obs.export_traces({}, format='csv', db=None)
475+
>>> hasattr(resp, 'media_type') and 'csv' in resp.media_type
476+
True
477+
>>> resp2 = obs.export_traces({}, format='ndjson', db=None)
478+
>>> type(resp2).__name__
479+
'StreamingResponse'
361480
"""
362481
# Standard
363482
import csv
@@ -437,6 +556,13 @@ def export_traces(
437556
elif format == "ndjson":
438557
# Newline-delimited JSON (streaming)
439558
def generate():
559+
"""Yield newline-delimited JSON strings for each trace.
560+
561+
This nested generator is used to stream NDJSON responses.
562+
563+
Yields:
564+
str: A JSON-encoded line (with trailing newline) for a trace.
565+
"""
440566
for t in traces:
441567
# Standard
442568
import json
@@ -475,7 +601,32 @@ def get_query_performance(hours: int = Query(24, ge=1, le=168, description="Time
475601
476602
Returns:
477603
dict: Performance analytics
604+
605+
Examples:
606+
>>> import mcpgateway.routers.observability as obs
607+
>>> class EmptyDB:
608+
... def query(self, *a, **k):
609+
... return self
610+
... def filter(self, *a, **k):
611+
... return self
612+
... def all(self):
613+
... return []
614+
>>> obs.get_query_performance(hours=1, db=EmptyDB())['total_traces']
615+
0
616+
617+
>>> class SmallDB:
618+
... def query(self, *a, **k):
619+
... return self
620+
... def filter(self, *a, **k):
621+
... return self
622+
... def all(self):
623+
... return [(10,), (20,), (30,), (40,)]
624+
>>> res = obs.get_query_performance(hours=1, db=SmallDB())
625+
>>> res['total_traces']
626+
4
627+
478628
"""
629+
479630
# Third-Party
480631

481632
# First-Party

0 commit comments

Comments
 (0)