11"""Backups RESTful API.""" 
22
3+ from  __future__ import  annotations 
4+ 
35import  asyncio 
46from  collections .abc  import  Callable 
57import  errno 
68import  logging 
79from  pathlib  import  Path 
810import  re 
911from  tempfile  import  TemporaryDirectory 
10- from  typing  import  Any 
12+ from  typing  import  Any ,  Literal 
1113
1214from  aiohttp  import  web 
1315from  aiohttp .hdrs  import  CONTENT_DISPOSITION 
2325    ATTR_CONTENT ,
2426    ATTR_DATE ,
2527    ATTR_DAYS_UNTIL_STALE ,
26-     ATTR_FILENAME ,
28+     ATTR_EXTRA ,
2729    ATTR_FOLDERS ,
2830    ATTR_HOMEASSISTANT ,
2931    ATTR_HOMEASSISTANT_EXCLUDE_DATABASE ,
4749from  ..exceptions  import  APIError , APIForbidden 
4850from  ..jobs  import  JobSchedulerOptions 
4951from  ..mounts .const  import  MountUsage 
52+ from  ..mounts .mount  import  Mount 
5053from  ..resolution .const  import  UnhealthyReason 
51- from  .const  import  ATTR_BACKGROUND , ATTR_LOCATIONS , CONTENT_TYPE_TAR 
54+ from  .const  import  (
55+     ATTR_ADDITIONAL_LOCATIONS ,
56+     ATTR_BACKGROUND ,
57+     ATTR_LOCATIONS ,
58+     CONTENT_TYPE_TAR ,
59+ )
5260from  .utils  import  api_process , api_validate 
5361
5462_LOGGER : logging .Logger  =  logging .getLogger (__name__ )
6068# Remove: 2022.08 
6169_ALL_FOLDERS  =  ALL_FOLDERS  +  [FOLDER_HOMEASSISTANT ]
6270
71+ 
72+ def  _ensure_list (item : Any ) ->  list :
73+     """Ensure value is a list.""" 
74+     if  not  isinstance (item , list ):
75+         return  [item ]
76+     return  item 
77+ 
78+ 
6379# pylint: disable=no-value-for-parameter 
6480SCHEMA_RESTORE_FULL  =  vol .Schema (
6581    {
8197        vol .Optional (ATTR_NAME ): str ,
8298        vol .Optional (ATTR_PASSWORD ): vol .Maybe (str ),
8399        vol .Optional (ATTR_COMPRESSED ): vol .Maybe (vol .Boolean ()),
84-         vol .Optional (ATTR_LOCATION ): vol .Maybe (str ),
100+         vol .Optional (ATTR_LOCATION ): vol .All (
101+             _ensure_list , [vol .Maybe (str )], vol .Unique ()
102+         ),
85103        vol .Optional (ATTR_HOMEASSISTANT_EXCLUDE_DATABASE ): vol .Boolean (),
86104        vol .Optional (ATTR_BACKGROUND , default = False ): vol .Boolean (),
105+         vol .Optional (ATTR_EXTRA ): dict ,
87106    }
88107)
89108
106125        vol .Optional (ATTR_TIMEOUT ): vol .All (int , vol .Range (min = 1 )),
107126    }
108127)
109- SCHEMA_RELOAD  =  vol .Schema (
110-     {
111-         vol .Inclusive (ATTR_LOCATION , "file" ): vol .Maybe (str ),
112-         vol .Inclusive (ATTR_FILENAME , "file" ): vol .Match (RE_BACKUP_FILENAME ),
113-     }
114- )
115128
116129
117130class  APIBackups (CoreSysAttributes ):
@@ -177,13 +190,10 @@ async def options(self, request):
177190        self .sys_backups .save_data ()
178191
179192    @api_process  
180-     async  def  reload (self , request :  web . Request ):
193+     async  def  reload (self , _ ):
181194        """Reload backup list.""" 
182-         body  =  await  api_validate (SCHEMA_RELOAD , request )
183-         self ._validate_cloud_backup_location (request , body .get (ATTR_LOCATION ))
184-         backup  =  self ._location_to_mount (body )
185- 
186-         return  await  asyncio .shield (self .sys_backups .reload (** backup ))
195+         await  asyncio .shield (self .sys_backups .reload ())
196+         return  True 
187197
188198    @api_process  
189199    async  def  backup_info (self , request ):
@@ -217,27 +227,37 @@ async def backup_info(self, request):
217227            ATTR_REPOSITORIES : backup .repositories ,
218228            ATTR_FOLDERS : backup .folders ,
219229            ATTR_HOMEASSISTANT_EXCLUDE_DATABASE : backup .homeassistant_exclude_database ,
230+             ATTR_EXTRA : backup .extra ,
220231        }
221232
222-     def  _location_to_mount (self , body : dict [str , Any ]) ->  dict [str , Any ]:
223-         """Change location field to mount if necessary.""" 
224-         if  not  body .get (ATTR_LOCATION ) or  body [ATTR_LOCATION ] ==  LOCATION_CLOUD_BACKUP :
225-             return  body 
233+     def  _location_to_mount (
234+         self , location : str  |  None 
235+     ) ->  Literal [LOCATION_CLOUD_BACKUP ] |  Mount  |  None :
236+         """Convert a single location to a mount if possible.""" 
237+         if  not  location  or  location  ==  LOCATION_CLOUD_BACKUP :
238+             return  location 
226239
227-         body [ ATTR_LOCATION ]  =  self .sys_mounts .get (body [ ATTR_LOCATION ] )
228-         if  body [ ATTR_LOCATION ] .usage  !=  MountUsage .BACKUP :
240+         mount  =  self .sys_mounts .get (location )
241+         if  mount .usage  !=  MountUsage .BACKUP :
229242            raise  APIError (
230-                 f"Mount { body [ ATTR_LOCATION ] .name }  
243+                 f"Mount { mount .name }  
231244            )
232245
246+         return  mount 
247+ 
248+     def  _location_field_to_mount (self , body : dict [str , Any ]) ->  dict [str , Any ]:
249+         """Change location field to mount if necessary.""" 
250+         body [ATTR_LOCATION ] =  self ._location_to_mount (body .get (ATTR_LOCATION ))
233251        return  body 
234252
235253    def  _validate_cloud_backup_location (
236-         self , request : web .Request , location : str  |  None 
254+         self , request : web .Request , location : list [ str   |   None ]  |   str  |  None 
237255    ) ->  None :
238256        """Cloud backup location is only available to Home Assistant.""" 
257+         if  not  isinstance (location , list ):
258+             location  =  [location ]
239259        if  (
240-             location   ==   LOCATION_CLOUD_BACKUP 
260+             LOCATION_CLOUD_BACKUP   in   location 
241261            and  request .get (REQUEST_FROM ) !=  self .sys_homeassistant 
242262        ):
243263            raise  APIForbidden (
@@ -278,10 +298,22 @@ async def release_on_freeze(new_state: CoreState):
278298    async  def  backup_full (self , request : web .Request ):
279299        """Create full backup.""" 
280300        body  =  await  api_validate (SCHEMA_BACKUP_FULL , request )
281-         self ._validate_cloud_backup_location (request , body .get (ATTR_LOCATION ))
301+         locations : list [Literal [LOCATION_CLOUD_BACKUP ] |  Mount  |  None ] |  None  =  None 
302+ 
303+         if  ATTR_LOCATION  in  body :
304+             location_names : list [str  |  None ] =  body .pop (ATTR_LOCATION )
305+             self ._validate_cloud_backup_location (request , location_names )
306+ 
307+             locations  =  [
308+                 self ._location_to_mount (location ) for  location  in  location_names 
309+             ]
310+             body [ATTR_LOCATION ] =  locations .pop (0 )
311+             if  locations :
312+                 body [ATTR_ADDITIONAL_LOCATIONS ] =  locations 
313+ 
282314        background  =  body .pop (ATTR_BACKGROUND )
283315        backup_task , job_id  =  await  self ._background_backup_task (
284-             self .sys_backups .do_backup_full , ** self . _location_to_mount ( body ) 
316+             self .sys_backups .do_backup_full , ** body 
285317        )
286318
287319        if  background  and  not  backup_task .done ():
@@ -299,10 +331,22 @@ async def backup_full(self, request: web.Request):
299331    async  def  backup_partial (self , request : web .Request ):
300332        """Create a partial backup.""" 
301333        body  =  await  api_validate (SCHEMA_BACKUP_PARTIAL , request )
302-         self ._validate_cloud_backup_location (request , body .get (ATTR_LOCATION ))
334+         locations : list [Literal [LOCATION_CLOUD_BACKUP ] |  Mount  |  None ] |  None  =  None 
335+ 
336+         if  ATTR_LOCATION  in  body :
337+             location_names : list [str  |  None ] =  body .pop (ATTR_LOCATION )
338+             self ._validate_cloud_backup_location (request , location_names )
339+ 
340+             locations  =  [
341+                 self ._location_to_mount (location ) for  location  in  location_names 
342+             ]
343+             body [ATTR_LOCATION ] =  locations .pop (0 )
344+             if  locations :
345+                 body [ATTR_ADDITIONAL_LOCATIONS ] =  locations 
346+ 
303347        background  =  body .pop (ATTR_BACKGROUND )
304348        backup_task , job_id  =  await  self ._background_backup_task (
305-             self .sys_backups .do_backup_partial , ** self . _location_to_mount ( body ) 
349+             self .sys_backups .do_backup_partial , ** body 
306350        )
307351
308352        if  background  and  not  backup_task .done ():
@@ -370,9 +414,11 @@ async def remove(self, request: web.Request):
370414        self ._validate_cloud_backup_location (request , backup .location )
371415        return  self .sys_backups .remove (backup )
372416
417+     @api_process  
373418    async  def  download (self , request : web .Request ):
374419        """Download a backup file.""" 
375420        backup  =  self ._extract_slug (request )
421+         self ._validate_cloud_backup_location (request , backup .location )
376422
377423        _LOGGER .info ("Downloading backup %s" , backup .slug )
378424        response  =  web .FileResponse (backup .tarfile )
@@ -385,7 +431,23 @@ async def download(self, request: web.Request):
385431    @api_process  
386432    async  def  upload (self , request : web .Request ):
387433        """Upload a backup file.""" 
388-         with  TemporaryDirectory (dir = str (self .sys_config .path_tmp )) as  temp_dir :
434+         location : Literal [LOCATION_CLOUD_BACKUP ] |  Mount  |  None  =  None 
435+         locations : list [Literal [LOCATION_CLOUD_BACKUP ] |  Mount  |  None ] |  None  =  None 
436+         tmp_path  =  self .sys_config .path_tmp 
437+         if  ATTR_LOCATION  in  request .query :
438+             location_names : list [str ] =  request .query .getall (ATTR_LOCATION )
439+             self ._validate_cloud_backup_location (request , location_names )
440+             # Convert empty string to None if necessary 
441+             locations  =  [
442+                 self ._location_to_mount (location ) if  location  else  None 
443+                 for  location  in  location_names 
444+             ]
445+             location  =  locations .pop (0 )
446+ 
447+             if  location  and  location  !=  LOCATION_CLOUD_BACKUP :
448+                 tmp_path  =  location .local_where 
449+ 
450+         with  TemporaryDirectory (dir = tmp_path .as_posix ()) as  temp_dir :
389451            tar_file  =  Path (temp_dir , "backup.tar" )
390452            reader  =  await  request .multipart ()
391453            contents  =  await  reader .next ()
@@ -398,15 +460,22 @@ async def upload(self, request: web.Request):
398460                        backup .write (chunk )
399461
400462            except  OSError  as  err :
401-                 if  err .errno  ==  errno .EBADMSG :
463+                 if  err .errno  ==  errno .EBADMSG  and  location  in  {
464+                     LOCATION_CLOUD_BACKUP ,
465+                     None ,
466+                 }:
402467                    self .sys_resolution .unhealthy  =  UnhealthyReason .OSERROR_BAD_MESSAGE 
403468                _LOGGER .error ("Can't write new backup file: %s" , err )
404469                return  False 
405470
406471            except  asyncio .CancelledError :
407472                return  False 
408473
409-             backup  =  await  asyncio .shield (self .sys_backups .import_backup (tar_file ))
474+             backup  =  await  asyncio .shield (
475+                 self .sys_backups .import_backup (
476+                     tar_file , location = location , additional_locations = locations 
477+                 )
478+             )
410479
411480        if  backup :
412481            return  {ATTR_SLUG : backup .slug }
0 commit comments