Skip to content

Commit

Permalink
restrict moving input files to source folders that contain "gradio" a…
Browse files Browse the repository at this point in the history
…s path element
  • Loading branch information
niwa2 committed Jan 26, 2025
1 parent cc4d389 commit ccd3777
Showing 1 changed file with 25 additions and 4 deletions.
29 changes: 25 additions & 4 deletions modules/whisper/base_transcription_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,26 @@ def run(self,
total_elapsed_time = time.time() - start_time
return result, total_elapsed_time

def isGradioCachePath(self, in_path: str):
"""
Check if given path is a gradio chache path
Parameters
----------
in_path
the path to check
Returns
----------
bool: True if it is a Gradio cache path, False otherwise
"""
path = os.path.normpath(in_path)
path = path.split(os.sep)
for p in path:
if p.lower() == "gradio":
return True
return False

def transcribe_file(self,
files: Optional[List] = None,
input_folder_path: Optional[str] = None,
Expand Down Expand Up @@ -266,11 +286,12 @@ def transcribe_file(self,
with tempfile.TemporaryDirectory(dir=".", prefix="_tmp_input") as tmp_folder:
tmp_files = []
for f in files:
if "PYTEST_CURRENT_TEST" in os.environ:
# during tests we do not want to move/delete our input files
tmp_files.append(shutil.copy(f, tmp_folder))
else:
if self.isGradioCachePath(f):
# likely a gradio cached file, so move it
tmp_files.append(shutil.move(f, tmp_folder))
else:
# likely no gradio cached file, use the original one
tmp_files.append(f)
for file in tmp_files:
transcribed_segments, time_for_task = self.run(
file,
Expand Down

0 comments on commit ccd3777

Please sign in to comment.