Commit bd6da794 authored by Pierre Paleo

Attempt at mitigating cuda memory issue

parent 7b20711b
......@@ -7,7 +7,7 @@ computations.py: determine computational needs, chunking method to be used, etc.
from silx.image.tomography import get_next_power
def estimate_required_memory(process_config, chunk_size=None, radios_and_slices=True):
def estimate_required_memory(process_config, chunk_size=None, radios_and_slices=True, warn_from_GB=None):
"""
Estimate the memory (RAM) needed for a reconstruction.
......@@ -15,6 +15,17 @@ def estimate_required_memory(process_config, chunk_size=None, radios_and_slices=
-----------
radios_and_slices: bool
Whether radios and slices will co-exist in memory (meaning more memory usage)
warn_from_GB: float, optional
Memory threshold, in GB, above which a warning flag is raised.
Default is None
If set to a number, the result will be in the form (estimated_memory_GB, warning)
where 'warning' is a boolean indicating whether memory allocation/transfer might be problematic.
Notes
-----
It seems that Cuda does not allow allocating and/or transferring more than 16384 MiB (17.18 GB).
If warn_from_GB is not None, then the result is in the form (estimated_memory_GB, warning)
where warning is a boolean indicating whether memory allocation/transfer might be problematic.
"""
dataset = process_config.dataset_infos
nabu_config = process_config.nabu_config
......@@ -25,6 +36,7 @@ def estimate_required_memory(process_config, chunk_size=None, radios_and_slices=
Ny = chunk_size
total_memory_needed = 0
memory_warning = False
# Read data
# ----------
......@@ -33,6 +45,9 @@ def estimate_required_memory(process_config, chunk_size=None, radios_and_slices=
data_volume_size = Nx * Ny * Na * 4
data_image_size = Nx * Ny * 4
total_memory_needed += data_volume_size
if (warn_from_GB is not None) and data_volume_size/1e9 > warn_from_GB:
memory_warning = True
print("warn for %d" % chunk_size)
# CCD processing
# ---------------
......@@ -82,7 +97,10 @@ def estimate_required_memory(process_config, chunk_size=None, radios_and_slices=
reconstructed_volume_size = Nx_rec * Ny_rec * Nz_rec * 4 # float32
total_memory_needed += reconstructed_volume_size
return total_memory_needed
if warn_from_GB is None:
return total_memory_needed
else:
return (total_memory_needed, memory_warning)
def estimate_chunk_size(available_memory_GB, process_config, chunk_step=50):
"""
......@@ -109,11 +127,14 @@ def estimate_chunk_size(available_memory_GB, process_config, chunk_step=50):
chunk_size = chunk_step
last_good_chunk_size = chunk_size
while True:
required_mem = estimate_required_memory(
process_config, chunk_size=chunk_size, radios_and_slices=radios_and_slices
(required_mem, mem_warn) = estimate_required_memory(
process_config,
chunk_size=chunk_size,
radios_and_slices=radios_and_slices,
warn_from_GB=17 # 2**32 elements - see estimate_required_memory docstring note
)
required_mem_GB = required_mem / 1e9
if required_mem_GB > available_memory_GB or chunk_size > max_dz:
if required_mem_GB > available_memory_GB or chunk_size > max_dz or mem_warn:
break
last_good_chunk_size = chunk_size
chunk_size += chunk_step
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment