Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions arcsolve/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,27 @@

import asyncio
import sys
import urllib.parse
import webbrowser

from arcsolve.server import build_server


def _parse_redirect(raw: str) -> tuple[str, str | None]:
"""사용자가 붙여넣은 값에서 (code, state)를 뽑는다.

'code='가 들어 있으면 redirect URL(또는 쿼리스트링)로 보고 파싱하고, 아니면 입력 전체를
code로 취급한다(state 없음). state가 있으면 exchange_code가 CSRF 대조에 쓴다.
"""
if "code=" not in raw:
return raw, None
qs = urllib.parse.urlparse(raw).query or raw
params = urllib.parse.parse_qs(qs)
code = (params.get("code") or [raw])[0]
state = (params.get("state") or [None])[0]
return code, state


def _auth(name: str) -> None:
from arcsolve.services import available, load_service

Expand All @@ -33,15 +49,16 @@ def _auth(name: str) -> None:

url = client.authorize_url_for_login()
print("브라우저에서 아래 URL을 열어 로그인/동의한 뒤,")
print("리다이렉트된 주소(redirect_uri)?code=... 값을 복사해 붙여넣으세요.\n")
print("리다이렉트된 주소(redirect_uri) 전체를 그대로 붙여넣으세요(또는 ?code=... 값만).\n")
print(url + "\n")
try:
webbrowser.open(url)
except Exception:
pass

code = input("code = ").strip()
tok = asyncio.run(client.exchange_code(code))
raw = input("redirect URL 전체(또는 code) = ").strip()
code, state = _parse_redirect(raw)
tok = asyncio.run(client.exchange_code(code, state=state))
print("\n✅ 인증 완료. ~/.arcsolve/credentials.json 에 저장했습니다(권한 0600).")
print("호스트 설정의 env에 refresh_token을 직접 넣어도 됩니다(평문 노출 주의):")
print(f" {name.upper()}_REFRESH_TOKEN={tok.get('refresh_token')}")
Expand Down
19 changes: 17 additions & 2 deletions arcsolve/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,22 @@

import httpx

from arcsolve import __version__

DEFAULT_TIMEOUT = 10.0

# 식별용 기본 User-Agent. UA 누락 시 403을 주는 API(NWS·Wikipedia 등)를 구조적으로 예방하고,
# 서비스마다 UA 문자열을 손으로 박는 drift를 없앤다. 호출자가 명시한 UA가 항상 우선한다.
DEFAULT_USER_AGENT = f"arcsolve/{__version__} (+https://github.com/ArcSolver/ArcSolve-Kit)"


def _with_default_ua(headers: dict | None) -> dict:
"""기본 User-Agent를 깔고 호출자 헤더로 덮어쓴다(호출자 UA가 우선)."""
merged = {"User-Agent": DEFAULT_USER_AGENT}
if headers:
merged.update(headers)
return merged


class UpstreamError(RuntimeError):
"""상류 API가 4xx/5xx를 반환했을 때. payload에 원본 응답(JSON 또는 text)을 담는다."""
Expand Down Expand Up @@ -46,7 +60,8 @@ async def _request_raw(
"""
async with httpx.AsyncClient(timeout=timeout, transport=transport) as client:
r = await client.request(
method, url, headers=headers, params=params, data=data, json=json, files=files
method, url, headers=_with_default_ua(headers),
params=params, data=data, json=json, files=files,
)
if r.status_code >= 400:
try:
Expand Down Expand Up @@ -137,7 +152,7 @@ async def get_text(
빈 문자열을 돌려준다. transport 주입으로 네트워크 없이 테스트할 수 있다.
"""
async with httpx.AsyncClient(timeout=timeout, transport=transport) as client:
r = await client.request("GET", url, headers=headers, params=params)
r = await client.request("GET", url, headers=_with_default_ua(headers), params=params)
if r.status_code >= 400:
try:
payload: dict | str = r.json()
Expand Down
44 changes: 36 additions & 8 deletions arcsolve/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,16 @@
import json
import os
import secrets
import tempfile
import time
import urllib.parse
from dataclasses import dataclass, field
from pathlib import Path

import httpx

from arcsolve.http import DEFAULT_USER_AGENT

DEFAULT_STORE = Path.home() / ".arcsolve" / "credentials.json"
_TIMEOUT = 10.0

Expand Down Expand Up @@ -56,11 +59,24 @@ def update(self, service: str, **fields) -> None:
os.chmod(self.path.parent, 0o700)
except OSError:
pass
self.path.write_text(json.dumps(data, indent=2, ensure_ascii=False))
payload = json.dumps(data, indent=2, ensure_ascii=False)
# 원자적 교체: 같은 디렉터리에 임시 파일로 쓰고 os.replace로 갈아끼운다. 쓰기 도중
# 중단/충돌이 나도 기존 credentials.json(전 서비스 토큰)이 잘리거나 손상되지 않는다.
fd, tmp = tempfile.mkstemp(dir=self.path.parent, prefix=".credentials-", suffix=".tmp")
try:
os.chmod(self.path, 0o600)
except OSError:
pass
with os.fdopen(fd, "w", encoding="utf-8") as f:
f.write(payload)
try:
os.chmod(tmp, 0o600) # 최종 파일 권한 0600 (replace는 임시파일 모드를 유지)
except OSError:
pass
os.replace(tmp, self.path)
except BaseException:
try:
os.unlink(tmp)
except OSError:
pass
raise


@dataclass
Expand All @@ -76,6 +92,7 @@ class OAuthClient:
store: TokenStore = field(default_factory=TokenStore)
transport: httpx.BaseTransport | None = None # 테스트 주입용
_verifier: str | None = field(default=None, init=False, repr=False)
_state: str | None = field(default=None, init=False, repr=False) # CSRF 방어용 state

async def access_token(self) -> str:
"""유효한 access token을 반환한다. 만료(또는 부재) 시 refresh로 자동 갱신."""
Expand All @@ -98,8 +115,17 @@ async def access_token(self) -> str:
self._save(tok, fallback_refresh=refresh)
return tok["access_token"]

async def exchange_code(self, code: str) -> dict:
"""authorization code를 토큰으로 교환하고 저장한다(최초 1회 인증)."""
async def exchange_code(self, code: str, state: str | None = None) -> dict:
"""authorization code를 토큰으로 교환하고 저장한다(최초 1회 인증).

`state`를 주면 authorize_url_for_login()이 만든 값과 대조한다(CSRF·인가코드 주입 방어).
수동 복붙 흐름에서 state를 모르면 None으로 생략할 수 있다(같은 프로세스에서 authorize URL을
만들었고 state도 함께 받은 경우에만 검증).
"""
if state is not None and self._state is not None and not secrets.compare_digest(
state, self._state
):
raise RuntimeError(f"{self.service}: OAuth state 불일치 — 인증을 처음부터 다시 하세요.")
data = {
"grant_type": "authorization_code",
"client_id": self.client_id,
Expand All @@ -115,6 +141,7 @@ async def exchange_code(self, code: str) -> dict:
def authorize_url_for_login(self) -> str:
verifier, challenge = _pkce_pair()
self._verifier = verifier
self._state = secrets.token_urlsafe(16) # 저장해 두고 exchange_code에서 대조
query = urllib.parse.urlencode(
{
"client_id": self.client_id,
Expand All @@ -123,16 +150,17 @@ def authorize_url_for_login(self) -> str:
"scope": " ".join(self.scopes),
"code_challenge": challenge,
"code_challenge_method": "S256",
"state": secrets.token_urlsafe(16),
"state": self._state,
}
)
return f"{self.authorize_url}?{query}"

async def _post_token(self, data: dict) -> dict:
if self.client_secret:
data = {**data, "client_secret": self.client_secret}
headers = {"User-Agent": DEFAULT_USER_AGENT}
async with httpx.AsyncClient(timeout=_TIMEOUT, transport=self.transport) as client:
r = await client.post(self.token_url, data=data)
r = await client.post(self.token_url, data=data, headers=headers)
r.raise_for_status()
return r.json()

Expand Down
3 changes: 3 additions & 0 deletions changelog.d/core-hardening.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
- core(http): 식별용 기본 User-Agent(`arcsolve/<version>`)를 모든 요청에 주입 — UA 누락 시 403을 주는 API(NWS·Wikipedia 등)를 구조적으로 예방하고 서비스별 UA 하드코딩 drift를 제거. 호출자가 명시한 UA는 항상 우선.
- core(oauth): OAuth `state`를 생성만 하던 것을 저장·대조하도록 보강(CSRF·인가코드 주입 방어). `exchange_code(code, state=?)` + `arcsolve auth`가 redirect URL 전체 붙여넣기를 받아 code/state를 파싱. 토큰 저장을 `tempfile`+`os.replace` 원자적 교체로 변경해 쓰기 중단 시 credentials.json 손상 방지. 토큰 엔드포인트에도 기본 UA 전송.
- core(pkg): 패키지 버전을 `arcsolve/__init__.py` 단일 출처로 통일(hatch dynamic version) — pyproject 이중 기록 제거로 릴리스 시 버전·UA·PyPI 메타데이터 drift 차단.
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "arcsolve"
version = "0.1.0"
dynamic = ["version"]
description = "Bundle the public APIs of popular services as MCP tools and Claude Skills, verified against official contracts (FastMCP)."
readme = "README.md"
requires-python = ">=3.11"
Expand Down Expand Up @@ -39,6 +39,10 @@ Issues = "https://github.com/ArcSolver/ArcSolve-Kit/issues"
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.hatch.version]
# 버전 단일 출처: arcsolve/__init__.py의 __version__ (pyproject와 이중 기록 제거).
path = "arcsolve/__init__.py"

[tool.hatch.build.targets.wheel]
packages = ["arcsolve"]

Expand Down
23 changes: 23 additions & 0 deletions tests/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,29 @@ def test_parse_link_header_empty():
assert parse_link_header("") == {}


async def test_default_user_agent_applied_when_caller_omits():
    """With no caller headers, the request carries the core default User-Agent."""
    captured: dict = {}

    async def handler(request):
        captured.update(ua=request.headers.get("user-agent"))
        return httpx.Response(200, json={"ok": True})

    mock_transport = _t(handler)
    await get_json("https://x", transport=mock_transport)

    observed = captured["ua"]
    assert observed.startswith("arcsolve/")  # core default UA was injected


async def test_caller_user_agent_overrides_default():
    """A User-Agent given by the caller replaces the core default."""
    captured: dict = {}

    async def handler(request):
        captured.update(ua=request.headers.get("user-agent"))
        return httpx.Response(200, json={"ok": True})

    # A service-specific UA (e.g. one an upstream API requires) must win over
    # the core default.
    custom_headers = {"User-Agent": "custom/1.0"}
    await get_json("https://x", headers=custom_headers, transport=_t(handler))

    assert captured["ua"] == "custom/1.0"


async def test_4xx_raises_upstream_error_with_payload():
async def handler(req):
return httpx.Response(401, json={"code": -401, "msg": "bad"})
Expand Down
54 changes: 53 additions & 1 deletion tests/test_oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,30 @@
import json
import os
import stat
from urllib.parse import parse_qs
from urllib.parse import parse_qs, urlparse

import httpx
import pytest

from arcsolve.oauth import OAuthClient, TokenStore


def _client(tmp_path, handler) -> OAuthClient:
    """Build an OAuthClient for service "x" on a temp token store and a mock transport."""
    token_store = TokenStore(tmp_path / "cred.json")
    mock_transport = httpx.MockTransport(handler)
    settings = {
        "service": "x",
        "token_url": "https://t/token",
        "authorize_url": "https://a/authorize",
        "client_id": "cid",
        "scopes": ["s"],
        "store": token_store,
        "transport": mock_transport,
    }
    return OAuthClient(**settings)


async def _ok_token(req):
    """Mock token endpoint: always answer 200 with a one-hour access token."""
    body = {"access_token": "AT", "expires_in": 3600}
    return httpx.Response(200, json=body)


async def test_exchange_code_uses_pkce_and_saves(tmp_path):
seen = {}

Expand Down Expand Up @@ -42,7 +59,42 @@ async def handler(req):
assert saved["x"]["refresh_token"] == "RT"


async def test_exchange_code_rejects_mismatched_state(tmp_path):
    """A state that differs from the one issued for login is rejected."""
    oauth = _client(tmp_path, _ok_token)
    oauth.authorize_url_for_login()  # makes the client generate its internal state

    wrong_state = "not-the-real-state"
    with pytest.raises(RuntimeError, match="state"):
        await oauth.exchange_code("CODE", state=wrong_state)


async def test_exchange_code_accepts_matching_state(tmp_path):
    """The state carried in the authorize URL is accepted by exchange_code."""
    oauth = _client(tmp_path, _ok_token)
    login_url = oauth.authorize_url_for_login()

    # Pull the state back out of the authorize URL the client just produced.
    query = parse_qs(urlparse(login_url).query)
    issued_state = query["state"][0]

    token = await oauth.exchange_code("CODE", state=issued_state)  # match -> passes
    assert token["access_token"] == "AT"


async def test_exchange_code_without_state_still_works(tmp_path):
    """Omitting state stays allowed (backward compatible manual copy-paste flow)."""
    oauth = _client(tmp_path, _ok_token)
    oauth.authorize_url_for_login()

    token = await oauth.exchange_code("CODE")
    assert token["access_token"] == "AT"


def test_token_store_file_is_0600(tmp_path):
    """The credentials file written by TokenStore.update has mode 0600."""
    cred_file = tmp_path / "sub" / "cred.json"
    store = TokenStore(cred_file)
    store.update("svc", access_token="AT")

    mode = stat.S_IMODE(os.stat(cred_file).st_mode)
    assert mode == 0o600


def test_token_store_update_is_atomic_and_leaves_no_temp(tmp_path):
    """Two updates merge into one file and leave no temp file behind."""
    folder = tmp_path / "sub"
    cred_file = folder / "cred.json"
    store = TokenStore(cred_file)

    store.update("svc", access_token="AT")
    store.update("svc2", refresh_token="RT")  # second update merges with the first

    names = [entry.name for entry in folder.iterdir()]
    names.sort()
    assert names == ["cred.json"]  # no leftover .credentials-*.tmp

    saved = json.loads(cred_file.read_text())
    assert saved["svc"]["access_token"] == "AT"
    assert saved["svc2"]["refresh_token"] == "RT"
1 change: 0 additions & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading