
Optimize the operations for lower memory usage to help on container deployment #73

Merged
1 commit merged into main on Mar 31, 2025

Conversation

jnation3406
Contributor


This combines a number of small fixes with memory optimizations for each operation. I'll make inline comments on the specifics.

@jnation3406 jnation3406 requested a review from LTDakin March 31, 2025 02:56
cache.set(f'operation_{self.cache_key}_output', output_data, CACHE_DURATION)
self.set_operation_progress(1.0)
self.set_status('COMPLETED')
Rearranged the order of these so that setting the status to COMPLETED happens last. This prevents a race condition where the frontend, which polls until it sees COMPLETED, could read that status before the output data was available in the cache and then stop polling without ever receiving the output data.
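The safe write order can be sketched as follows. Note this is an illustration: the progress and status key names are assumptions (only the output key appears in the snippet above), and a plain dict stands in for the shared cache.

```python
# Plain dict stands in for the shared cache; key naming mirrors the PR's
# 'operation_<key>_output' pattern, the other suffixes are assumed.
cache = {}

def finish_operation(cache_key, output_data):
    # 1. Publish the output first...
    cache[f'operation_{cache_key}_output'] = output_data
    # 2. ...then report full progress...
    cache[f'operation_{cache_key}_progress'] = 1.0
    # 3. ...and only then flip the status the frontend polls on.
    cache[f'operation_{cache_key}_status'] = 'COMPLETED'

def poll(cache_key):
    # The frontend stops polling once it sees COMPLETED; with the ordering
    # above, the output is guaranteed to already be in the cache by then.
    if cache.get(f'operation_{cache_key}_status') == 'COMPLETED':
        return cache[f'operation_{cache_key}_output']
    return None
```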

@@ -31,8 +31,9 @@ def __init__(self, cache_key: str, data: np.array, comment: str=None) -> None:
comment (str): Optionally add a comment to add to the FITS file.
"""
self.datalab_id = cache_key
self.primary_hdu = fits.PrimaryHDU(header=fits.Header([('KEY', cache_key)]))
self.image_hdu = fits.ImageHDU(data=data, name='SCI')
self.primary_hdu = fits.PrimaryHDU(header=fits.Header([('DLAB_KEY', cache_key)]))
Changed the datalab key's header keyword from KEY to DLAB_KEY so it is more specific.

@@ -22,7 +22,7 @@ class FITSOutputHandler():
data (np.array): The data for the image HDU.
"""

def __init__(self, cache_key: str, data: np.array, comment: str=None) -> None:
def __init__(self, cache_key: str, data: np.array, comment: str=None, data_header: fits.Header=None) -> None:
Allows this constructor to accept a data (SCI) HDU header to place into the image HDU it saves. We should probably extend this a bit more to accept additional HDUs. For example, I can see wanting to copy over the CAT HDU for operations that look at the same area, which is basically all of them, so that we can load the source catalog when viewing the results at full size. If you take a median and then subtract it out from an image, you might do that to make the sources disappear or to make a single source pop, and seeing the source catalog overlaid might be nice.
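A minimal sketch of what the extended constructor might look like, assuming the new data_header is passed straight through to the ImageHDU. The class and parameter names come from the diff; the body is an assumption.

```python
import numpy as np
from astropy.io import fits

class FITSOutputHandler:
    """Sketch only: the real class has more methods (create_and_save_data_products)."""

    def __init__(self, cache_key: str, data: np.ndarray, comment: str = None,
                 data_header: fits.Header = None) -> None:
        self.datalab_id = cache_key
        self.primary_hdu = fits.PrimaryHDU(
            header=fits.Header([('DLAB_KEY', cache_key)]))
        # Reuse the source SCI header (WCS etc.) if one was passed in;
        # astropy copies its cards into the new image HDU's header.
        self.image_hdu = fits.ImageHDU(data=data, header=data_header, name='SCI')
        if comment:
            self.primary_hdu.header.add_comment(comment)
```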

@@ -27,15 +29,27 @@ def __init__(self, basename: str, source: str = None) -> None:
self.source = source
self.exit_stack = ExitStack()
self.fits_file = self.exit_stack.enter_context(get_fits(basename, source))
self.sci_data = get_hdu(self.fits_file, 'SCI').data
self.sci_hdu = get_hdu(self.fits_file, 'SCI')
We were already getting the SCI HDU and copying it out of the file, so this just saves a reference to it on this class so we can pull out the header if we need it without opening the file and copying the HDU again.

def __enter__(self):
return self

def __exit__(self, exc_type, exc_val, exc_tb):
Allows this class to be used as a context manager in some cases, with an explicit garbage collection when the file and data exit the context. This is for operations that process files sequentially, like subtraction and normalization, so we should be able to support an arbitrary number of files without running out of memory.
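The idea can be sketched like this. The class skeleton follows the diff, but the file handling is stubbed out and the __exit__ body is an assumption about where the cleanup happens:

```python
import gc
from contextlib import ExitStack

class InputDataHandler:
    """Sketch: the real class opens a FITS file via get_fits/get_hdu."""

    def __init__(self, basename, source=None):
        self.basename = basename
        self.source = source
        self.exit_stack = ExitStack()
        self.sci_hdu = None  # placeholder for get_hdu(self.fits_file, 'SCI')

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Close the file and drop the large array reference, then force a
        # collection so memory is reclaimed before the next file is opened.
        self.exit_stack.close()
        self.sci_hdu = None
        gc.collect()
        return False  # don't swallow exceptions
```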

@@ -101,16 +101,20 @@ def create_tif(fits_paths: np.ndarray, tif_path, color=False, zmin=None, zmax=No
max_height, max_width = max(get_fits_dimensions(fp) for fp in fits_paths)
fits_to_img(fits_paths, tif_path, TIFF_EXTENSION, width=max_width, height=max_height, color=color, zmin=zmin, zmax=zmax)

def crop_arrays(array_list: list):
def crop_arrays(array_list: list, flatten=False):
Changed to optionally flatten the data with ravel() when cropping, and to also return the cropped data size alongside the data arrays, since it might be useful later.
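A hedged sketch of what crop_arrays with the new flag might look like; the real function's cropping logic and return shape may differ.

```python
import numpy as np

def crop_arrays(array_list, flatten=False):
    # Crop every array to the smallest common height/width
    min_h = min(a.shape[0] for a in array_list)
    min_w = min(a.shape[1] for a in array_list)
    cropped = [a[:min_h, :min_w] for a in array_list]
    if flatten:
        # ravel() returns a flat view when the data is contiguous,
        # otherwise a copy; either way downstream code gets 1-D arrays
        cropped = [a.ravel() for a in cropped]
    # Also return the cropped size so callers can reshape flattened data
    return cropped, (min_h, min_w)
```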

# Retry network connection errors 3 times, all other exceptions are not retried
def should_retry(retries_so_far, exception):
return retries_so_far < 3 and isinstance(exception, RequestException)

@dramatiq.actor(retry_when=should_retry)
@dramatiq.actor(retry_when=should_retry, time_limit=TIME_LIMIT)
The default time limit is 10 minutes; increasing it to 1 hour since I hit the limit doing a 20-image normalization locally. We could make it a little higher, but eventually we want hung operations to time out and fail so they don't block things forever.
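For reference, dramatiq actor time limits are specified in milliseconds (the default is 600000, i.e. 10 minutes), so the constant behind this change would presumably look like:

```python
# dramatiq's time_limit option is in milliseconds; the library default of
# 600000 ms (10 minutes) is what the 20-image normalization exceeded.
TIME_LIMIT = 60 * 60 * 1000  # 3600000 ms = 1 hour
```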



@receiver(post_delete, sender=DataOperation)
def cb_dataoperation_post_delete(sender, instance, *args, **kwargs):
After an operation is deleted, if its status was FAILED, clear all of the cached values for the operation so that if it is requested again it will be re-attempted instead of immediately returning the FAILED status.
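The handler's logic might look like the sketch below, with a plain dict standing in for Django's cache; the per-operation key suffixes are assumptions, and the real handler is wired up via the post_delete receiver shown above.

```python
from types import SimpleNamespace

# Plain dict stands in for Django's cache framework
cache = {}

def cb_dataoperation_post_delete(sender, instance, *args, **kwargs):
    key = instance.cache_key
    # Only clear failed operations: a repeated request should retry them
    # rather than immediately return the cached FAILED status. Completed
    # results stay cached.
    if cache.get(f'operation_{key}_status') == 'FAILED':
        for suffix in ('status', 'progress', 'output', 'message'):
            cache.pop(f'operation_{key}_{suffix}', None)
```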

@@ -54,9 +54,9 @@ def operate(self):
self.set_operation_progress(0.5 * (index / len(input_list)))

# Creating the Median array
cropped_data = crop_arrays([image.sci_data for image in input_fits_list])
stacked_ndarray = np.stack(cropped_data, axis=2)
np.stack doubles the memory usage of the image data; this can be avoided by taking the median of flattened arrays and then reshaping back to the original cropped size.
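A sketch of the stack-free median under that description (the helper name is hypothetical):

```python
import numpy as np

def median_images(cropped_arrays):
    """Median-combine equally sized 2-D arrays without np.stack(..., axis=2)."""
    shape = cropped_arrays[0].shape
    # ravel() gives flat views for contiguous arrays, so the explicit
    # (H, W, N) cube never has to be built up front
    flats = [a.ravel() for a in cropped_arrays]
    # np.median still materializes one working array internally, but the
    # separate long-lived stacked copy is avoided
    medianed = np.median(flats, axis=0, overwrite_input=True)
    return medianed.reshape(shape)
```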

output = FITSOutputHandler(f'{self.cache_key}', normalized_image, comment).create_and_save_data_products(index=index)
output_files.append(output)
self.set_operation_progress(0.5 + index/len(input_fits_list) * 0.4)
for index, input in enumerate(input_list, start=1):
Changed the ordering of this operation so it acts on the input images sequentially and cleans up memory between images. This should allow it to operate on large numbers of images without running out of memory.
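The sequential pattern can be sketched as below. _FakeImage is a self-contained stand-in for InputDataHandler (the real class closes its FITS file and garbage-collects on exit, per the context-manager change in this PR):

```python
class _FakeImage:
    """Stand-in for InputDataHandler: releases its data on __exit__."""
    def __init__(self, name):
        self.name = name
        self.sci_data = [1.0]  # placeholder pixel data
    def __enter__(self):
        return self
    def __exit__(self, *exc):
        self.sci_data = None
        return False

def process_sequentially(names, process_one):
    # Only one image's data is alive at a time, so peak memory stays flat
    # no matter how many inputs there are
    outputs = []
    for name in names:
        with _FakeImage(name) as image:
            outputs.append(process_one(image))
    return outputs
```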

for index, input_image in enumerate(input_fits_list, start=1):
# crop the input_image and subtraction_image to the same size
input_image, subtraction_image = crop_arrays([input_image.sci_data, subtraction_fits.sci_data])
for index, input in enumerate(input_files, start=1):
Same as normalization: this operation can also process the images sequentially.

@@ -54,12 +53,11 @@ def operate(self):
input_fits_list.append(InputDataHandler(input['basename'], input['source']))
self.set_operation_progress(0.5 * (index / len(input_files)))

cropped_data = crop_arrays([image.sci_data for image in input_fits_list])
stacked_ndarray = np.stack(cropped_data, axis=2)
Again, we can avoid doubling the memory with np.stack by summing the cropped arrays directly.
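A sketch of the stack-free summation (the helper name is hypothetical): accumulate into one output buffer instead of building the (H, W, N) cube that np.stack would allocate.

```python
import numpy as np

def sum_images(cropped_arrays):
    # One full-size accumulator instead of an N-image stacked cube
    total = np.zeros_like(cropped_arrays[0], dtype=np.float64)
    for arr in cropped_arrays:
        total += arr  # in-place accumulation, no extra full-size temporaries
    return total
```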

@@ -111,17 +112,6 @@ def _align_images(self, fits_files: list[str]) -> list[str]:
return fits_files

return aligned_images

# Currently storing the output fits SCI HDU as a 3D ndarray consisting of each input's SCI data
def _create_3d_array(self, input_handlers: list[InputDataHandler]) -> np.ndarray:
Removed the output FITS file's 3-D array, which was unnecessary.

create_jpgs(aligned_images, large_jpg_path, small_jpg_path, color=True, zmin=zmin_list, zmax=zmax_list)
except Exception as ex:
# Catches exceptions in the fits2image methods to report back to frontend
raise ClientAlertException(ex)
The fits2image library throws exceptions when the input image shapes don't match well. This try/except catches those exceptions and re-raises them as ClientAlertException so the UI can show the specific error to the user instead of a generic one.
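The wrap-and-re-raise pattern, sketched with a stand-in ClientAlertException; only the exception name comes from the diff, the helper is illustrative.

```python
class ClientAlertException(Exception):
    """Raised so the frontend can display the message to the user."""

def run_client_visible(fn, *args, **kwargs):
    try:
        return fn(*args, **kwargs)
    except Exception as ex:
        # Preserve the underlying message but mark it as user-displayable;
        # 'from ex' keeps the original traceback chained for server logs
        raise ClientAlertException(ex) from ex
```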

Contributor

@LTDakin LTDakin left a comment


Looks awesome, thanks for working on this. I can see these changes drastically improving the operation of datalab.

@LTDakin LTDakin merged commit 77400a6 into main Mar 31, 2025
3 checks passed
@LTDakin LTDakin deleted the fix/optimization_pass_1 branch March 31, 2025 17:59
2 participants