Commit 4a6836b2 authored by payno

Merge branch 'allow_setting_specific_gpu' into 'main'

Allow setting specific gpu

See merge request !11
parents eeb63de8 1896e1ca
Pipeline #144412 passed
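
For context, a minimal sketch of what this merge request enables, based on the updated interpret_slurm_config staticmethod and the tests further down. The import path and the way sbatch_extra_params is normally handed to SBatchScriptJob are assumptions, not shown in this diff:

from sluurp.job import SBatchScriptJob  # import path is an assumption, not shown in this diff

# Ask for 2 GPUs and constrain the allocation to a specific card through the new
# `sbatch_extra_params` argument; the raw sinfo card name is stripped to a form
# that the -C option accepts ("nvidia_a40" -> "a40").
slurm_lines, preprocessing_lines = SBatchScriptJob.interpret_slurm_config(
    slurm_config={"n_gpus": 2, "partition": "gpu", "job_name": "my_job"},
    sbatch_extra_params={"gpu_card": "nvidia_a40"},
)
for line in slurm_lines:
    print(line)
# expected sbatch resource lines (order follows the config dict):
#   #SBATCH --gres=gpu:2 -C a40
#   #SBATCH -p gpu
#   #SBATCH -J 'my_job'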
from .executor import submit # noqa F401
__version__ = "0.3.2"
__version__ = "0.3.3"
@@ -8,6 +8,8 @@ from .utils import has_scancel_available, has_scontrol_available
import time
import logging
from uuid import uuid4
from packaging import version
from platform import python_version
_logger = logging.getLogger(__name__)
@@ -231,7 +233,8 @@ class SBatchScriptJob(ScriptJob):
super()._write_script_preprocessing_lines(file_object=file_object)
# handle first
slurm_lines, pre_processing_lines = self.interpret_slurm_config(
self._slurm_config
self._slurm_config,
self._sbatch_extra_params,
)
# define out file
output_file_path = self._get_output_file_path()
@@ -257,7 +260,29 @@ class SBatchScriptJob(ScriptJob):
return None
@staticmethod
def interpret_slurm_config(slurm_config: dict) -> tuple:
def strip_gpu_card_name(gpu_name: str):
"""
Today the GPU names returned by sinfo are rejected when they are passed to the -C option:
they appear to be prefixed (nvidia_, tesla_) and suffixed (-sxm2-32gb, -pcie-32gb).
For now we strip this extra information, but a more coherent mapping is still needed.
"""
if version.parse(python_version()) >= version.parse("3.9"):
gpu_name = gpu_name.removeprefix("nvidia_")
gpu_name = gpu_name.removeprefix("tesla_")
gpu_name = gpu_name.removesuffix("-sxm2-32gb")
gpu_name = gpu_name.removesuffix("-pcie-32gb")
else:
gpu_name = gpu_name.replace("nvidia_", "")
gpu_name = gpu_name.replace("tesla_", "")
gpu_name = gpu_name.replace("-sxm2-32gb", "")
gpu_name = gpu_name.replace("-pcie-32gb", "")
return gpu_name
@staticmethod
def interpret_slurm_config(
slurm_config: dict, sbatch_extra_params: Optional[dict] = None
) -> tuple:
"""
convert a slurm configuration dictionary to a tuple of two tuples.
The first tuple will provide the lines to add to the shell script for sbatch (resources specification)
@@ -269,6 +294,8 @@ class SBatchScriptJob(ScriptJob):
raise TypeError(
f"slurm_config is expected to be a dict. {type(slurm_config)} provided"
)
if sbatch_extra_params is None:
sbatch_extra_params = {}
slurm_ressources = []
preprocessing = []
for key, value in slurm_config.items():
@@ -283,7 +310,12 @@
elif key == "partition":
slurm_ressources.append(f"#SBATCH -p {value}")
elif key == "n_gpus":
slurm_ressources.append(f"#SBATCH --gres=gpu:{value}")
gpu_line = f"#SBATCH --gres=gpu:{value}"
gpu_card = sbatch_extra_params.get("gpu_card", None)
if gpu_card is not None:
gpu_card = SBatchScriptJob.strip_gpu_card_name(gpu_card)
gpu_line += f" -C {gpu_card}"
slurm_ressources.append(gpu_line)
elif key == "job_name":
slurm_ressources.append(f"#SBATCH -J '{value}'")
elif key == "walltime":
@@ -192,8 +192,30 @@ def test_interpret_slurm_config(caplog):
"tomotools",
"pycharm/11.7.1",
),
}
},
sbatch_extra_params={
"gpu_card": "a40"
}, # just to make sure this doesn't add GPU options if no GPU is requested
) == (
("#SBATCH --mem=10GB",),
("module load tomotools", "module load pycharm/11.7.1"),
)
assert SBatchScriptJob.interpret_slurm_config(
slurm_config={"n_gpus": "3", "partition": "my_partition"},
sbatch_extra_params={"gpu_card": "a40"},
) == (
(
"#SBATCH --gres=gpu:3 -C a40",
"#SBATCH -p my_partition",
),
(),
)
def test_strip_gpu_card_name():
"""test `strip_gpu_card_name` function"""
assert SBatchScriptJob.strip_gpu_card_name("a40") == "a40"
assert SBatchScriptJob.strip_gpu_card_name("nvidia_a40") == "a40"
assert SBatchScriptJob.strip_gpu_card_name("tesla_56") == "56"
assert SBatchScriptJob.strip_gpu_card_name("tesla_v100-pcie-32gb") == "v100"
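
As a complementary illustration (a hypothetical extra test, not part of this merge request), interpret_slurm_config runs the gpu_card value through strip_gpu_card_name before emitting the -C constraint:

def test_interpret_slurm_config_strips_gpu_card_name():
    # hypothetical check combining the two new helpers: the raw sinfo name
    # "tesla_v100-pcie-32gb" is reduced to "v100" before being passed to -C
    assert SBatchScriptJob.interpret_slurm_config(
        slurm_config={"n_gpus": 1},
        sbatch_extra_params={"gpu_card": "tesla_v100-pcie-32gb"},
    ) == (
        ("#SBATCH --gres=gpu:1 -C v100",),
        (),
    )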