1212"""
1313
1414import glob
15+ import hashlib
1516import json
17+ import os
18+ import pickle
1619import random
1720import re
21+ import shutil
1822import sys
23+ import tempfile
1924import time
2025from enum import Enum
2126from os .path import basename , dirname , exists
@@ -106,10 +111,12 @@ def __init__(
         dataset=None,
         table=None,
         billing_project=None,
+        use_cache=True,
     ):
         """Instantiate DryRun class."""
         self.sqlfile = sqlfile
         self.content = content
+        self.use_cache = use_cache
         self.query_parameters = query_parameters
         self.strip_dml = strip_dml
         self.use_cloud_function = use_cloud_function
@@ -192,6 +199,17 @@ def skipped_files(sql_dir=ConfigLoader.get("default", "sql_dir")) -> Set[str]:
 
         return skip_files
 
+    @staticmethod
+    def clear_cache():
+        """Clear dry run cache directory."""
+        cache_dir = Path(tempfile.gettempdir()) / "bigquery_etl_dryrun_cache"
+        if cache_dir.exists():
+            try:
+                shutil.rmtree(cache_dir)
+                print(f"Cleared dry run cache at {cache_dir}")
+            except OSError as e:
+                print(f"Warning: Failed to clear dry run cache: {e}")
+
     def skip(self):
         """Determine if dry run should be skipped."""
         return self.respect_skip and self.sqlfile in self.skipped_files(
@@ -225,6 +243,108 @@ def get_sql(self):
 
         return sql
 
+    def _get_cache_key(self, sql):
+        """Generate cache key based on SQL content and other parameters."""
+        cache_input = f"{sql}|{self.project}|{self.dataset}|{self.table}"
+        return hashlib.sha256(cache_input.encode()).hexdigest()
+
+    @staticmethod
+    def _get_cache_dir():
+        """Get the cache directory path."""
+        cache_dir = Path(tempfile.gettempdir()) / "bigquery_etl_dryrun_cache"
+        cache_dir.mkdir(parents=True, exist_ok=True)
+        return cache_dir
+
+    def _read_cache_file(self, cache_file, ttl_seconds):
+        """Read and return cached data from a pickle file with TTL check."""
+        try:
+            if not cache_file.exists():
+                return None
+
+            # check if cache is expired
+            file_age = time.time() - cache_file.stat().st_mtime
+            if file_age > ttl_seconds:
+                try:
+                    cache_file.unlink()
+                except OSError:
+                    pass
+                return None
+
+            cached_data = pickle.loads(cache_file.read_bytes())
+            return cached_data
+        except (pickle.PickleError, EOFError, OSError, FileNotFoundError) as e:
+            print(f"[CACHE] Failed to load {cache_file}: {e}")
+            try:
+                if cache_file.exists():
+                    cache_file.unlink()
+            except OSError:
+                pass
+            return None
+
+    @staticmethod
+    def _write_cache_file(cache_file, data):
+        """Write data to a cache file using atomic write."""
+        try:
+            # write to temporary file first, then atomically rename
+            # this prevents race conditions where readers get partial files
+            # include random bytes to handle thread pool scenarios where threads share same PID
+            temp_file = Path(
+                str(cache_file) + f".tmp.{os.getpid()}.{os.urandom(4).hex()}"
+            )
+            with open(temp_file, "wb") as f:
+                pickle.dump(data, f)
+                f.flush()
+                os.fsync(f.fileno())  # Ensure data is written to disk
+
+            temp_file.replace(cache_file)
+        except (pickle.PickleError, OSError) as e:
+            print(f"[CACHE] Failed to save {cache_file}: {e}")
+            try:
+                if "temp_file" in locals() and temp_file.exists():
+                    temp_file.unlink()
+            except (OSError, NameError):
+                pass
+
+    def _get_cached_result(self, cache_key, ttl_seconds=None):
+        """Load cached dry run result from disk."""
+        if ttl_seconds is None:
+            ttl_seconds = ConfigLoader.get("dry_run", "cache_ttl_seconds", fallback=900)
+
+        cache_file = self._get_cache_dir() / f"dryrun_{cache_key}.pkl"
+        return self._read_cache_file(cache_file, ttl_seconds)
+
+    def _save_cached_result(self, cache_key, result):
+        """Save dry run result to disk cache using atomic write."""
+        cache_file = self._get_cache_dir() / f"dryrun_{cache_key}.pkl"
+        self._write_cache_file(cache_file, result)
+
+        # save table metadata separately if present
+        if (
+            result
+            and "tableMetadata" in result
+            and self.project
+            and self.dataset
+            and self.table
+        ):
+            table_identifier = f"{self.project}.{self.dataset}.{self.table}"
+            self._save_cached_table_metadata(table_identifier, result["tableMetadata"])
+
+    def _get_cached_table_metadata(self, table_identifier, ttl_seconds=None):
+        """Load cached table metadata from disk based on table identifier."""
+        if ttl_seconds is None:
+            ttl_seconds = ConfigLoader.get("dry_run", "cache_ttl_seconds", fallback=900)
+
+        # table identifier as cache key
+        table_cache_key = hashlib.sha256(table_identifier.encode()).hexdigest()
+        cache_file = self._get_cache_dir() / f"table_metadata_{table_cache_key}.pkl"
+        return self._read_cache_file(cache_file, ttl_seconds)
+
+    def _save_cached_table_metadata(self, table_identifier, metadata):
+        """Save table metadata to disk cache using atomic write."""
+        table_cache_key = hashlib.sha256(table_identifier.encode()).hexdigest()
+        cache_file = self._get_cache_dir() / f"table_metadata_{table_cache_key}.pkl"
+        self._write_cache_file(cache_file, metadata)
+
     @cached_property
     def dry_run_result(self):
         """Dry run the provided SQL file."""
@@ -233,6 +353,14 @@ def dry_run_result(self):
         else:
             sql = self.get_sql()
 
+        # check cache first (if caching is enabled)
+        if sql is not None and self.use_cache:
+            cache_key = self._get_cache_key(sql)
+            cached_result = self._get_cached_result(cache_key)
+            if cached_result is not None:
+                self.dry_run_duration = 0  # Cached result, no actual dry run
+                return cached_result
+
         query_parameters = []
         if self.query_parameters:
             for parameter_name, parameter_type in self.query_parameters.items():
@@ -351,6 +479,12 @@ def dry_run_result(self):
                 }
 
             self.dry_run_duration = time.time() - start_time
+
+            # Save to cache (if caching is enabled and result is valid)
+            # Don't cache errors to allow retries
+            # Guard on sql being set, since cache_key is only computed above when sql is not None
+            if self.use_cache and sql is not None and result.get("valid"):
+                self._save_cached_result(cache_key, result)
+
             return result
 
         except Exception as e:
@@ -476,6 +610,13 @@ def get_table_schema(self):
         ):
             return self.dry_run_result["tableMetadata"]["schema"]
 
+        # Check if table metadata is cached (if caching is enabled)
+        if self.use_cache and self.project and self.dataset and self.table:
+            table_identifier = f"{self.project}.{self.dataset}.{self.table}"
+            cached_metadata = self._get_cached_table_metadata(table_identifier)
+            if cached_metadata:
+                return cached_metadata["schema"]
+
         return []
 
     def get_dataset_labels(self):
@@ -565,6 +706,13 @@ def validate_schema(self):
             return True
 
         query_file_path = Path(self.sqlfile)
+        table_name = query_file_path.parent.name
+        dataset_name = query_file_path.parent.parent.name
+        project_name = query_file_path.parent.parent.parent.name
+        self.project = project_name
+        self.dataset = dataset_name
+        self.table = table_name
+
         query_schema = Schema.from_json(self.get_schema())
         if self.errors():
             # ignore file when there are errors that self.get_schema() did not raise
@@ -576,26 +724,7 @@ def validate_schema(self):
             click.echo(f"No schema file defined for {query_file_path}", err=True)
             return True
 
-        table_name = query_file_path.parent.name
-        dataset_name = query_file_path.parent.parent.name
-        project_name = query_file_path.parent.parent.parent.name
-
-        partitioned_by = None
-        if (
-            self.metadata
-            and self.metadata.bigquery
-            and self.metadata.bigquery.time_partitioning
-        ):
-            partitioned_by = self.metadata.bigquery.time_partitioning.field
-
-        table_schema = Schema.for_table(
-            project_name,
-            dataset_name,
-            table_name,
-            client=self.client,
-            id_token=self.id_token,
-            partitioned_by=partitioned_by,
-        )
+        table_schema = Schema.from_json(self.get_table_schema())
 
         # This check relies on the new schema being deployed to prod
         if not query_schema.compatible(table_schema):