Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 98 additions & 32 deletions supervisor/api/backups.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Backups RESTful API."""

from __future__ import annotations

import asyncio
from collections.abc import Callable
import errno
Expand All @@ -14,7 +16,7 @@
import voluptuous as vol

from ..backups.backup import Backup
from ..backups.const import LOCATION_CLOUD_BACKUP
from ..backups.const import LOCATION_CLOUD_BACKUP, LOCATION_TYPE
from ..backups.validate import ALL_FOLDERS, FOLDER_HOMEASSISTANT, days_until_stale
from ..const import (
ATTR_ADDONS,
Expand All @@ -23,7 +25,7 @@
ATTR_CONTENT,
ATTR_DATE,
ATTR_DAYS_UNTIL_STALE,
ATTR_FILENAME,
ATTR_EXTRA,
ATTR_FOLDERS,
ATTR_HOMEASSISTANT,
ATTR_HOMEASSISTANT_EXCLUDE_DATABASE,
Expand All @@ -48,7 +50,12 @@
from ..jobs import JobSchedulerOptions
from ..mounts.const import MountUsage
from ..resolution.const import UnhealthyReason
from .const import ATTR_BACKGROUND, ATTR_LOCATIONS, CONTENT_TYPE_TAR
from .const import (
ATTR_ADDITIONAL_LOCATIONS,
ATTR_BACKGROUND,
ATTR_LOCATIONS,
CONTENT_TYPE_TAR,
)
from .utils import api_process, api_validate

_LOGGER: logging.Logger = logging.getLogger(__name__)
Expand All @@ -60,6 +67,14 @@
# Remove: 2022.08
_ALL_FOLDERS = ALL_FOLDERS + [FOLDER_HOMEASSISTANT]


def _ensure_list(item: Any) -> list:
"""Ensure value is a list."""
if not isinstance(item, list):
return [item]
return item


# pylint: disable=no-value-for-parameter
SCHEMA_RESTORE_FULL = vol.Schema(
{
Expand All @@ -81,9 +96,12 @@
vol.Optional(ATTR_NAME): str,
vol.Optional(ATTR_PASSWORD): vol.Maybe(str),
vol.Optional(ATTR_COMPRESSED): vol.Maybe(vol.Boolean()),
vol.Optional(ATTR_LOCATION): vol.Maybe(str),
vol.Optional(ATTR_LOCATION): vol.All(
_ensure_list, [vol.Maybe(str)], vol.Unique()
),
vol.Optional(ATTR_HOMEASSISTANT_EXCLUDE_DATABASE): vol.Boolean(),
vol.Optional(ATTR_BACKGROUND, default=False): vol.Boolean(),
vol.Optional(ATTR_EXTRA): dict,
}
)

Expand All @@ -106,12 +124,6 @@
vol.Optional(ATTR_TIMEOUT): vol.All(int, vol.Range(min=1)),
}
)
SCHEMA_RELOAD = vol.Schema(
{
vol.Inclusive(ATTR_LOCATION, "file"): vol.Maybe(str),
vol.Inclusive(ATTR_FILENAME, "file"): vol.Match(RE_BACKUP_FILENAME),
}
)


class APIBackups(CoreSysAttributes):
Expand Down Expand Up @@ -177,13 +189,10 @@ async def options(self, request):
self.sys_backups.save_data()

@api_process
async def reload(self, request: web.Request):
async def reload(self, _):
"""Reload backup list."""
body = await api_validate(SCHEMA_RELOAD, request)
self._validate_cloud_backup_location(request, body.get(ATTR_LOCATION))
backup = self._location_to_mount(body)

return await asyncio.shield(self.sys_backups.reload(**backup))
await asyncio.shield(self.sys_backups.reload())
return True

@api_process
async def backup_info(self, request):
Expand Down Expand Up @@ -217,27 +226,35 @@ async def backup_info(self, request):
ATTR_REPOSITORIES: backup.repositories,
ATTR_FOLDERS: backup.folders,
ATTR_HOMEASSISTANT_EXCLUDE_DATABASE: backup.homeassistant_exclude_database,
ATTR_EXTRA: backup.extra,
}

def _location_to_mount(self, body: dict[str, Any]) -> dict[str, Any]:
"""Change location field to mount if necessary."""
if not body.get(ATTR_LOCATION) or body[ATTR_LOCATION] == LOCATION_CLOUD_BACKUP:
return body
def _location_to_mount(self, location: str | None) -> LOCATION_TYPE:
"""Convert a single location to a mount if possible."""
if not location or location == LOCATION_CLOUD_BACKUP:
return location

body[ATTR_LOCATION] = self.sys_mounts.get(body[ATTR_LOCATION])
if body[ATTR_LOCATION].usage != MountUsage.BACKUP:
mount = self.sys_mounts.get(location)
if mount.usage != MountUsage.BACKUP:
raise APIError(
f"Mount {body[ATTR_LOCATION].name} is not used for backups, cannot backup to there"
f"Mount {mount.name} is not used for backups, cannot backup to there"
)

return mount

def _location_field_to_mount(self, body: dict[str, Any]) -> dict[str, Any]:
"""Change location field to mount if necessary."""
body[ATTR_LOCATION] = self._location_to_mount(body.get(ATTR_LOCATION))
return body

def _validate_cloud_backup_location(
self, request: web.Request, location: str | None
self, request: web.Request, location: list[str | None] | str | None
) -> None:
"""Cloud backup location is only available to Home Assistant."""
if not isinstance(location, list):
location = [location]
if (
location == LOCATION_CLOUD_BACKUP
LOCATION_CLOUD_BACKUP in location
and request.get(REQUEST_FROM) != self.sys_homeassistant
):
raise APIForbidden(
Expand Down Expand Up @@ -278,10 +295,22 @@ async def release_on_freeze(new_state: CoreState):
async def backup_full(self, request: web.Request):
"""Create full backup."""
body = await api_validate(SCHEMA_BACKUP_FULL, request)
self._validate_cloud_backup_location(request, body.get(ATTR_LOCATION))
locations: list[LOCATION_TYPE] | None = None

if ATTR_LOCATION in body:
location_names: list[str | None] = body.pop(ATTR_LOCATION)
self._validate_cloud_backup_location(request, location_names)

locations = [
self._location_to_mount(location) for location in location_names
]
body[ATTR_LOCATION] = locations.pop(0)
if locations:
body[ATTR_ADDITIONAL_LOCATIONS] = locations

background = body.pop(ATTR_BACKGROUND)
backup_task, job_id = await self._background_backup_task(
self.sys_backups.do_backup_full, **self._location_to_mount(body)
self.sys_backups.do_backup_full, **body
)

if background and not backup_task.done():
Expand All @@ -299,10 +328,22 @@ async def backup_full(self, request: web.Request):
async def backup_partial(self, request: web.Request):
"""Create a partial backup."""
body = await api_validate(SCHEMA_BACKUP_PARTIAL, request)
self._validate_cloud_backup_location(request, body.get(ATTR_LOCATION))
locations: list[LOCATION_TYPE] | None = None

if ATTR_LOCATION in body:
location_names: list[str | None] = body.pop(ATTR_LOCATION)
self._validate_cloud_backup_location(request, location_names)

locations = [
self._location_to_mount(location) for location in location_names
]
body[ATTR_LOCATION] = locations.pop(0)
if locations:
body[ATTR_ADDITIONAL_LOCATIONS] = locations

background = body.pop(ATTR_BACKGROUND)
backup_task, job_id = await self._background_backup_task(
self.sys_backups.do_backup_partial, **self._location_to_mount(body)
self.sys_backups.do_backup_partial, **body
)

if background and not backup_task.done():
Expand Down Expand Up @@ -370,9 +411,11 @@ async def remove(self, request: web.Request):
self._validate_cloud_backup_location(request, backup.location)
return self.sys_backups.remove(backup)

@api_process
async def download(self, request: web.Request):
"""Download a backup file."""
backup = self._extract_slug(request)
self._validate_cloud_backup_location(request, backup.location)

_LOGGER.info("Downloading backup %s", backup.slug)
response = web.FileResponse(backup.tarfile)
Expand All @@ -385,7 +428,23 @@ async def download(self, request: web.Request):
@api_process
async def upload(self, request: web.Request):
"""Upload a backup file."""
with TemporaryDirectory(dir=str(self.sys_config.path_tmp)) as temp_dir:
location: LOCATION_TYPE = None
locations: list[LOCATION_TYPE] | None = None
tmp_path = self.sys_config.path_tmp
if ATTR_LOCATION in request.query:
location_names: list[str] = request.query.getall(ATTR_LOCATION)
self._validate_cloud_backup_location(request, location_names)
# Convert empty string to None if necessary
locations = [
self._location_to_mount(location) if location else None
for location in location_names
]
location = locations.pop(0)

if location and location != LOCATION_CLOUD_BACKUP:
tmp_path = location.local_where

with TemporaryDirectory(dir=tmp_path.as_posix()) as temp_dir:
tar_file = Path(temp_dir, "backup.tar")
reader = await request.multipart()
contents = await reader.next()
Expand All @@ -398,15 +457,22 @@ async def upload(self, request: web.Request):
backup.write(chunk)

except OSError as err:
if err.errno == errno.EBADMSG:
if err.errno == errno.EBADMSG and location in {
LOCATION_CLOUD_BACKUP,
None,
}:
self.sys_resolution.unhealthy = UnhealthyReason.OSERROR_BAD_MESSAGE
_LOGGER.error("Can't write new backup file: %s", err)
return False

except asyncio.CancelledError:
return False

backup = await asyncio.shield(self.sys_backups.import_backup(tar_file))
backup = await asyncio.shield(
self.sys_backups.import_backup(
tar_file, location=location, additional_locations=locations
)
)

if backup:
return {ATTR_SLUG: backup.slug}
Expand Down
1 change: 1 addition & 0 deletions supervisor/api/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

COOKIE_INGRESS = "ingress_session"

ATTR_ADDITIONAL_LOCATIONS = "additional_locations"
ATTR_AGENT_VERSION = "agent_version"
ATTR_APPARMOR_VERSION = "apparmor_version"
ATTR_ATTRIBUTES = "attributes"
Expand Down
Loading
Loading