1111#   limitations under the License. 
1212"""This module provides generic tools for classes in ops/""" 
1313from  builtins  import  map , zip 
14- import  json 
14+ import  marshal 
1515import  os 
16- from  ast  import  literal_eval 
1716
1817import  numpy 
1918import  sympy 
3736from  openfermion .transforms .opconversions .term_reordering  import  normal_ordered 
3837
3938
39+ # Maximum size allowed for data files read by load_operator(). This is a (weak) safety 
40+ # measure against corrupted or insecure files. 
41+ _MAX_TEXT_OPERATOR_DATA  =  5  *  1024  *  1024 
42+ _MAX_BINARY_OPERATOR_DATA  =  1024  *  1024 
43+ 
44+ 
class OperatorUtilsError(Exception):
    """Error raised by the operator utilities in this module.

    The save_operator() docstring lists this exception for the case where
    the operator is not saved because the target file already exists.
    """

    pass
4247
@@ -247,27 +252,45 @@ def get_file_path(file_name, data_directory):
247252
248253
249254def  load_operator (file_name = None , data_directory = None , plain_text = False ):
250-     """Load FermionOperator or QubitOperator  from file. 
255+     """Load an operator (such as a FermionOperator)  from a  file. 
251256
252257    Args: 
253-         file_name: The name of the saved  file. 
258+         file_name: The name of the data  file to read . 
254259        data_directory: Optional data directory to change from default data 
255-                          directory specified in config file. 
260+             directory specified in config file. 
256261        plain_text: Whether the input file is plain text 
257262
258263    Returns: 
259264        operator: The stored FermionOperator, BosonOperator, 
260-             QuadOperator, or QubitOperator 
265+             QuadOperator, or QubitOperator.  
261266
262267    Raises: 
263268        TypeError: Operator of invalid type. 
269+         ValueError: If the file is larger than the maximum allowed. 
270+         ValueError: If the file content is not as expected or loading fails. 
271+         IOError: If the file cannot be opened. 
272+ 
273+     Warning: 
274+         Loading from binary files (plain_text=False) uses the Python 'marshal' 
275+         module, which is not secure against untrusted or maliciously crafted 
276+         data. Only load binary operator files from sources that you trust. 
277+         Prefer using the plain_text format for data from untrusted sources. 
264278    """ 
279+ 
265280    file_path  =  get_file_path (file_name , data_directory )
266281
282+     operator_type  =  None 
283+     operator_terms  =  None 
284+ 
267285    if  plain_text :
268286        with  open (file_path , 'r' ) as  f :
269-             data  =  f .read ()
270-             operator_type , operator_terms  =  data .split (":\n " )
287+             data  =  f .read (_MAX_TEXT_OPERATOR_DATA )
288+         try :
289+             operator_type , operator_terms  =  data .split (":\n " , 1 )
290+         except  ValueError :
291+             raise  ValueError (
292+                 "Invalid format in plain-text data file {file_path}: "  "expected 'TYPE:\\ nTERMS'" 
293+             )
271294
272295        if  operator_type  ==  'FermionOperator' :
273296            operator  =  FermionOperator (operator_terms )
@@ -278,15 +301,24 @@ def load_operator(file_name=None, data_directory=None, plain_text=False):
278301        elif  operator_type  ==  'QuadOperator' :
279302            operator  =  QuadOperator (operator_terms )
280303        else :
281-             raise  TypeError ('Operator of invalid type.' )
304+             raise  TypeError (
305+                 f"Invalid operator type '{ operator_type }  
306+                 f"found in plain-text data file '{ file_path }  
307+             )
282308    else :
283-         with  open (file_path , 'r' ) as  f :
284-             data  =  json .load (f )
285-         operator_type , serializable_terms  =  data 
286-         operator_terms  =  {
287-             literal_eval (key ): complex (value [0 ], value [1 ])
288-             for  key , value  in  serializable_terms .items ()
289-         }
309+         # marshal.load() doesn't have a size parameter, so we test it ourselves. 
310+         if  os .path .getsize (file_path ) >  _MAX_BINARY_OPERATOR_DATA :
311+             raise  ValueError (
312+                 f"Size of { file_path }  
313+                 f"({ _MAX_BINARY_OPERATOR_DATA }  
314+             )
315+         try :
316+             with  open (file_path , 'rb' ) as  f :
317+                 raw_data  =  marshal .load (f )
318+         except  Exception  as  e :
319+             raise  ValueError (f"Failed to load marshaled data from { file_path } { e }  )
320+ 
321+         operator_type , operator_terms  =  _validate_operator_data (raw_data )
290322
291323        if  operator_type  ==  'FermionOperator' :
292324            operator  =  FermionOperator ()
@@ -313,17 +345,17 @@ def load_operator(file_name=None, data_directory=None, plain_text=False):
313345def  save_operator (
314346    operator , file_name = None , data_directory = None , allow_overwrite = False , plain_text = False 
315347):
316-     """Save FermionOperator or QubitOperator to  file. 
348+     """Save an operator (such as a FermionOperator) to a  file. 
317349
318350    Args: 
319351        operator: An instance of FermionOperator, BosonOperator, 
320352            or QubitOperator. 
321353        file_name: The name of the saved file. 
322354        data_directory: Optional data directory to change from default data 
323-                          directory specified in config file. 
355+             directory specified in config file. 
324356        allow_overwrite: Whether to allow files to be overwritten. 
325357        plain_text: Whether the operator should be saved to a 
326-                          plain-text format for manual analysis 
358+             plain-text format for manual analysis.  
327359
328360    Raises: 
329361        OperatorUtilsError: Not saved, file already exists. 
@@ -360,6 +392,55 @@ def save_operator(
360392            f .write (operator_type  +  ":\n "  +  str (operator ))
361393    else :
362394        tm  =  operator .terms 
363-         serializable_terms  =  {str (key ): (value .real , value .imag ) for  key , value  in  tm .items ()}
364-         with  open (file_path , 'w' ) as  f :
365-             json .dump ((operator_type , serializable_terms ), f )
395+         with  open (file_path , 'wb' ) as  f :
396+             marshal .dump ((operator_type , dict (zip (tm .keys (), map (complex , tm .values ())))), f )
397+ 
398+ 
def _validate_operator_data(raw_data):
    """Validate the structure and types of data loaded using marshal.

    The data is expected to be a tuple of (type, data), where "type" is the
    name of one of the currently-supported operators and "data" is a dict.

    Args:
        raw_data: The Python object produced by unmarshaling the data file
            (i.e., the return value of marshal.load()).

    Returns:
        tuple(str, dict) where the 0th element is the name of the operator
            type (e.g., 'FermionOperator') and the dict is the operator data.

    Raises:
        TypeError: raw_data is not a tuple of length 2.
        TypeError: the first element of the tuple is not a string.
        TypeError: the given operator type is not supported.
        TypeError: the second element of the tuple is not a dict.
    """
    if not isinstance(raw_data, tuple) or len(raw_data) != 2:
        # Report the length as well when the container type is right but the
        # size is wrong; otherwise "got tuple instead" would be confusing.
        if isinstance(raw_data, tuple):
            found = f"a tuple of length {len(raw_data)}"
        else:
            found = f"type {type(raw_data).__name__}"
        raise TypeError(
            "Invalid marshaled structure: Expected a tuple "
            f"of length 2, but got {found} instead."
        )

    operator_type, operator_terms = raw_data

    if not isinstance(operator_type, str):
        raise TypeError(
            "Invalid type for operator_type: Expected str but "
            f"got type {type(operator_type).__name__} instead."
        )

    allowed = ('BosonOperator', 'FermionOperator', 'QuadOperator', 'QubitOperator')
    if operator_type not in allowed:
        # The names are joined explicitly so that the message is deterministic
        # and does not depend on the repr of a container.
        raise TypeError(
            f"Operator type '{operator_type}' is not supported. "
            f"The operator must be one of {', '.join(allowed)}."
        )

    if not isinstance(operator_terms, dict):
        raise TypeError(
            "Invalid type for operator_terms: Expected dict "
            f"but got type {type(operator_terms).__name__} instead."
        )

    return operator_type, operator_terms
0 commit comments