Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ RemoteShell provides the following MCP tools for remote server management:

**Example**: *"Show me which servers I have configured"* → Returns list of all saved servers with online status

### 💾 `save_server(connection_id, host, user, auth_type, credential)`
### 💾 `save_server(connection_id, host, user, auth_type, credential, port)`

**Purpose**: Create or update a server profile with authentication credentials.

Expand All @@ -111,6 +111,7 @@ RemoteShell provides the following MCP tools for remote server management:
- `credential`:
- For `password`: Plain text password string
- For `private_key`: File path (e.g., `~/.ssh/id_rsa`) or PEM key content
- `port`: SSH port (optional; defaults to 22 and keeps the existing saved port if omitted)

**When to use**:
- Adding a new server configuration
Expand Down
24 changes: 16 additions & 8 deletions src/remoteshell_mcp/host_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,27 +170,35 @@ def upsert(
connection_id: str,
host: str,
user: str,
port: Optional[int] = None,
auth_type: AuthType,
credential: str,
) -> ServerConfig:
data = self._load_raw()
servers: Dict[str, Any] = data.get("servers", {})
existing = servers.get(connection_id)
last_connected = None
existing_port: int = 22
if isinstance(existing, dict):
last_connected = existing.get("last_connected")
try:
existing_port = int(existing.get("port", 22))
except Exception:
existing_port = 22

chosen_port = existing_port if port is None else int(port)

cfg = ServerConfig(
connection_id=connection_id,
host=host,
user=user,
port=22,
port=chosen_port,
auth_type=auth_type,
password=credential if auth_type == "password" else None,
private_key=credential if auth_type == "private_key" else None,
)
cfg.validate()

data = self._load_raw()
servers: Dict[str, Any] = data.get("servers", {})
existing = servers.get(connection_id)
last_connected = None
if isinstance(existing, dict):
last_connected = existing.get("last_connected")

cfg.last_connected = last_connected
payload = cfg.to_dict()
payload.pop("connection_id", None) # connection_id is stored as the dict key
Expand Down
15 changes: 12 additions & 3 deletions src/remoteshell_mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,22 +103,29 @@ def list_servers() -> Dict[str, Any]:
"Persist (create or update) a server connection profile in the local host store.\n\n"
"When to use: When the user provides new SSH details, or after an auth_failed error to update credentials.\n"
"When NOT to use: Do not ask for credentials again if they are already saved and still valid.\n\n"
'Example: save_server(connection_id="srv1", host="1.2.3.4", user="root", auth_type="password", credential="<password>")'
'Example: save_server(connection_id="srv1", host="1.2.3.4", user="root", auth_type="password", credential="<password>", port=2222)'
)
)
def save_server(
connection_id: Annotated[str, Field(description="Unique identifier for this server connection")],
host: Annotated[str, Field(description="Server hostname or IP address")],
user: Annotated[str, Field(description="SSH username")],
auth_type: Annotated[str, Field(description="Authentication method: 'password', 'ssh_key', or 'ssh_agent'")],
credential: Annotated[str, Field(description="Password for 'password' auth, or path to private key for 'ssh_key' auth (empty for 'ssh_agent')")],
auth_type: Annotated[str, Field(description="Authentication method: 'password' or 'private_key'")],
credential: Annotated[str, Field(description="Password for 'password' auth, or path/PEM text for 'private_key' auth")],
port: Annotated[
Optional[int],
Field(
description="SSH port. Defaults to 22. If omitted, keeps the existing saved port (if any)."
),
] = None,
) -> Dict[str, Any]:
manager = _manager()
try:
cfg = manager.host_store.upsert(
connection_id=connection_id,
host=host,
user=user,
port=port,
auth_type=auth_type, # type: ignore[arg-type]
credential=credential,
)
Expand Down Expand Up @@ -243,6 +250,7 @@ def upload_file(
return {
"success": bool(result.get("success")),
"connection_id": connection_id,
"port": getattr(client, "port", None),
"local_path": result.get("local_path", chosen_local_path),
"remote_path": result.get("remote_path", remote_path),
"size": result.get("size"),
Expand Down Expand Up @@ -280,6 +288,7 @@ def download_file(
return {
"success": bool(result.get("success")),
"connection_id": connection_id,
"port": getattr(client, "port", None),
"remote_path": result.get("remote_path", remote_path),
"local_path": result.get("local_path", chosen_local_path),
"size": result.get("size"),
Expand Down
2 changes: 1 addition & 1 deletion task.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
1. 全面按照以下idea优化改造工具,不要考虑兼容性,直接移除不需要的工具,考虑在描述中加入工具的基本描述,什么时候使用这个工具,什么时候不应该使用,使用的示例。
工具名称,参数 (Parameters),LLM 侧描述 (Description)
list_servers,无,【用途】 获取本地保存的所有远程服务器配置清单。包含 ID、主机名、用户名及其在线状态。【何时使用】 当用户提到“连接服务器”、“查看机器”或未指定目标 ID 时,首先调用此工具查看可用资源。【示例】 “查看我有哪些服务器。”
save_server,"connection_id, host, user, auth_type, credential",【用途】 持久化保存服务器连接信息到本地加密库。支持 password 或 private_key。【何时使用】 当用户提供新的服务器信息,或由于现有凭据失效(AUTH_FAILED)需要更新时调用。【注意事项】 成功保存后,后续操作仅需引用 connection_id。请勿在对话中重复索要已保存的信息。【示例】 “保存我的阿里云服务器,IP x.x.x.x,用户 root,密码 xxx。”
save_server,"connection_id, host, user, auth_type, credential, port",【用途】 持久化保存服务器连接信息到本地加密库。支持自定义 SSH 端口、password 或 private_key。【何时使用】 当用户提供新的服务器信息,或由于现有凭据失效(AUTH_FAILED)需要更新时调用。【注意事项】 成功保存后,后续操作仅需引用 connection_id。请勿在对话中重复索要已保存的信息。【示例】 “保存我的服务器,IP x.x.x.x,用户 root,端口 2222,密码 xxx。”
remove_server,connection_id,【用途】 从本地库中彻底删除指定的服务器配置。【何时使用】 仅当用户明确要求“忘记”、“删除”或“移除”某台机器的配置时使用。【注意事项】 操作不可逆,删除后需重新调用 save_server 才能再次连接。
execute_command,"connection_id, command","【用途】 在远程服务器上执行非交互式 Shell 命令并返回结果。【何时使用】 所有的状态查询(ls, top, df)、文件操作(cp, mv)或脚本运行。【不适用场景】 严禁执行需要实时交互的命令(如 vim, htop, 或需要手动确认 [Y/n] 的命令,除非使用了 -y 参数)。【注意事项】 如涉及敏感目录,请尝试使用 sudo 前缀。【示例】 execute_command(connection_id=""srv1"", command=""df -h"")"
upload_file,"connection_id, local_path, remote_path","【用途】 将本地计算机的文件安全传输到远程服务器。【何时使用】 部署配置文件、上传脚本或代码包到远程。【注意事项】 确保远程目标目录存在写权限。如果 remote_path 仅是一个目录,文件名将保持与本地一致。【示例】 upload_file(connection_id=""srv1"", local_path=""./config.yaml"", remote_path=""/etc/app/"")"
Expand Down
2 changes: 2 additions & 0 deletions tests/test_connection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def test_connection_manager_connects_and_updates_last_connected(tmp_path: Path,
connection_id="srv1",
host="1.2.3.4",
user="root",
port=2222,
auth_type="password",
credential="secret",
)
Expand All @@ -28,6 +29,7 @@ def test_connection_manager_connects_and_updates_last_connected(tmp_path: Path,
manager = ConnectionManager(store)
client = manager.get_or_create_connection("srv1")
assert isinstance(client, RemoteSSHClient)
assert client.port == 2222

cfg = store.get("srv1")
assert cfg is not None
Expand Down
27 changes: 27 additions & 0 deletions tests/test_host_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,39 @@ def test_host_store_upsert_list_get_remove(tmp_path: Path):
credential="secret",
)
assert cfg.connection_id == "srv1"
assert cfg.port == 22
assert cfg.last_connected is None

cfg2 = store.get("srv1")
assert cfg2 is not None
assert cfg2.host == "1.2.3.4"
assert cfg2.auth_type == "password"
assert cfg2.port == 22

# Can set a custom SSH port.
store.upsert(
connection_id="srv1",
host="1.2.3.4",
user="root",
port=2222,
auth_type="password",
credential="secret",
)
cfg_port = store.get("srv1")
assert cfg_port is not None
assert cfg_port.port == 2222

# Updating credentials without providing a port keeps the existing saved port.
store.upsert(
connection_id="srv1",
host="1.2.3.4",
user="root",
auth_type="password",
credential="secret2",
)
cfg_keep = store.get("srv1")
assert cfg_keep is not None
assert cfg_keep.port == 2222

listed = store.list()
assert [c.connection_id for c in listed] == ["srv1"]
Expand Down