Skip to content

Commit 6c6705a

Browse files
committed
feat: update slang prediction endpoint to support batch input
1 parent 3fa5dfa commit 6c6705a

File tree

2 files changed

+15
-11
lines changed

2 files changed

+15
-11
lines changed

apps/classifier/app.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
PredictionRequest,
66
PredictionResponse,
77
)
8-
from model import predict
8+
from model import predict, predict_batch
99

1010
app = FastAPI()
1111

@@ -26,6 +26,8 @@ async def improve_reply_predict(data: PredictionRequest):
2626

2727
@app.post("/slang-predict", response_model=SlangPredictionResponse)
2828
async def slang_predict(data: SlangPredictionRequest):
29-
text = data.input
30-
predicted = predict(text, type="slang")
31-
return {"predicted": predicted[0], "probability": predicted[1]}
29+
text = data.inputs
30+
predicted = predict_batch(text, type="slang")
31+
return {
32+
"predictions": [{"predicted": p[0], "probability": p[1]} for p in predicted]
33+
}

apps/classifier/schemas.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,29 @@
11
from pydantic import BaseModel
2+
from typing import List
23

34

45
class SlangPredictionRequest(BaseModel):
5-
input: str
6+
inputs: List[str]
67

78
class Config:
89
json_schema_extra = {
910
"example": {
10-
"input": "X같네",
11+
"inputs": ["X같네"],
1112
}
1213
}
1314

1415

15-
class SlangPredictionResponse(BaseModel):
16+
class SlangPredictionItem(BaseModel):
1617
predicted: str
1718
probability: float
1819

20+
21+
class SlangPredictionResponse(BaseModel):
22+
predictions: List[SlangPredictionItem]
23+
1924
class Config:
2025
json_schema_extra = {
21-
"example": {
22-
"predicted": "욕설",
23-
"probability": 0.99,
24-
}
26+
"example": {"predictions": [{"predicted": "욕설", "probability": 0.99}]}
2527
}
2628

2729

0 commit comments

Comments
 (0)