diff --git a/dask_jobqueue/core.py b/dask_jobqueue/core.py index 8a006166..e3377183 100644 --- a/dask_jobqueue/core.py +++ b/dask_jobqueue/core.py @@ -143,43 +143,43 @@ def __init__( super().__init__() + default_config_name = self.default_config_name() if config_name is None: - config_name = getattr(type(self), "config_name") - if config_name is None: - raise ValueError( - "Looks like you are trying to create a class that inherits from dask_jobqueue.core.Job. " - "If that is the case, you need to:\n" - "- set the 'config_name' class variable to a non-None value\n" - "- create a section in jobqueue.yaml with the value of 'config_name'\n" - "If that is not the case, please open an issue in https://github.com/dask/dask-jobqueue/issues." - ) + config_name = default_config_name + self.config_name = config_name if job_name is None: - job_name = dask.config.get("jobqueue.%s.name" % config_name) + job_name = dask.config.get("jobqueue.%s.name" % self.config_name) if cores is None: - cores = dask.config.get("jobqueue.%s.cores" % config_name) + cores = dask.config.get("jobqueue.%s.cores" % self.config_name) if memory is None: - memory = dask.config.get("jobqueue.%s.memory" % config_name) + memory = dask.config.get("jobqueue.%s.memory" % self.config_name) if processes is None: - processes = dask.config.get("jobqueue.%s.processes" % config_name) + processes = dask.config.get("jobqueue.%s.processes" % self.config_name) if interface is None: - interface = dask.config.get("jobqueue.%s.interface" % config_name) + interface = dask.config.get("jobqueue.%s.interface" % self.config_name) if death_timeout is None: - death_timeout = dask.config.get("jobqueue.%s.death-timeout" % config_name) + death_timeout = dask.config.get( + "jobqueue.%s.death-timeout" % self.config_name + ) if local_directory is None: local_directory = dask.config.get( - "jobqueue.%s.local-directory" % config_name + "jobqueue.%s.local-directory" % self.config_name ) if extra is None: - extra = dask.config.get("jobqueue.%s.extra" % config_name) + extra = dask.config.get("jobqueue.%s.extra" % self.config_name) if env_extra is None: - env_extra = dask.config.get("jobqueue.%s.env-extra" % config_name) + env_extra = dask.config.get("jobqueue.%s.env-extra" % self.config_name) if header_skip is None: - header_skip = dask.config.get("jobqueue.%s.header-skip" % config_name, ()) + header_skip = dask.config.get( + "jobqueue.%s.header-skip" % self.config_name, () + ) if log_directory is None: - log_directory = dask.config.get("jobqueue.%s.log-directory" % config_name) + log_directory = dask.config.get( + "jobqueue.%s.log-directory" % self.config_name + ) if shebang is None: - shebang = dask.config.get("jobqueue.%s.shebang" % config_name) + shebang = dask.config.get("jobqueue.%s.shebang" % self.config_name) if cores is None or memory is None: job_class_name = self.__class__.__name__ @@ -191,7 +191,7 @@ def __init__( ) ) - # This attribute should be overridden + # This attribute should be set in the derived class self.job_header = None if interface: @@ -239,6 +239,18 @@ def __init__( if not os.path.exists(self.log_directory): os.makedirs(self.log_directory) + @classmethod + def default_config_name(cls): + config_name = getattr(cls, "config_name", None) + if config_name is None: + raise ValueError( + "The class {} is required to have 'config_name' class variable.\n" + "If you have created this class, please add a 'config_name' class variable.\n" + "If not this may be a bug, feel free to create an issue at: " + "https://github.com/dask/dask-jobqueue/issues/new".format(cls) + ) + return config_name + def job_script(self): """ Construct a job submission script """ header = "\n".join( @@ -392,8 +404,6 @@ class JobQueueCluster(SpecCluster): cluster_parameters=cluster_parameters ) - job_cls = None - def __init__( self, n_workers=0, @@ -414,18 +424,29 @@ def __init__( **kwargs ): self.status = "created" + + default_job_cls = getattr(type(self), "job_cls", None) + self.job_cls = default_job_cls if job_cls is not None: self.job_cls = job_cls if self.job_cls is None: raise ValueError( - "You must provide a Job type like PBSJob, SLURMJob, " - "or SGEJob with the job_cls= argument." + "You need to specify a Job type. Two cases:\n" + "- you are inheriting from JobQueueCluster (most likely): you need to add a 'job_cls' class variable " + "in your JobQueueCluster-derived class {}\n" + "- you are using JobQueueCluster directly (less likely, only useful for tests): " + "please explicitly pass a Job type through the 'job_cls' parameter.".format( + type(self) + ) ) - if config_name: - if interface is None: - interface = dask.config.get("jobqueue.%s.interface" % config_name) + default_config_name = self.job_cls.default_config_name() + if config_name is None: + config_name = default_config_name + + if interface is None: + interface = dask.config.get("jobqueue.%s.interface" % config_name) scheduler = { "cls": Scheduler, # Use local scheduler for now @@ -437,8 +458,8 @@ def __init__( "security": security, }, } - if config_name: - kwargs["config_name"] = config_name + + kwargs["config_name"] = config_name kwargs["interface"] = interface kwargs["protocol"] = protocol kwargs["security"] = security diff --git a/dask_jobqueue/htcondor.py b/dask_jobqueue/htcondor.py index 0735a569..ff4b6d7a 100644 --- a/dask_jobqueue/htcondor.py +++ b/dask_jobqueue/htcondor.py @@ -31,28 +31,29 @@ class HTCondorJob(Job): # Python (can't find its libs), so we have to go through the shell. executable = "/bin/sh" - def __init__( - self, *args, disk=None, job_extra=None, config_name="htcondor", **kwargs - ): + config_name = "htcondor" + + def __init__(self, *args, disk=None, job_extra=None, config_name=None, **kwargs): + super().__init__(*args, config_name=config_name, **kwargs) + if disk is None: - disk = dask.config.get("jobqueue.%s.disk" % config_name) + disk = dask.config.get("jobqueue.%s.disk" % self.config_name) if disk is None: raise ValueError( "You must specify how much disk to use per job like ``disk='1 GB'``" ) self.worker_disk = parse_bytes(disk) if job_extra is None: - self.job_extra = dask.config.get("jobqueue.%s.job-extra" % config_name, {}) + self.job_extra = dask.config.get( + "jobqueue.%s.job-extra" % self.config_name, {} + ) else: self.job_extra = job_extra - # Instantiate args and parameters from parent abstract class - super().__init__(*args, config_name=config_name, **kwargs) - env_extra = kwargs.get("env_extra", None) if env_extra is None: env_extra = dask.config.get( - "jobqueue.%s.env-extra" % config_name, default=[] + "jobqueue.%s.env-extra" % self.config_name, default=[] ) self.env_dict = self.env_lines_to_dict(env_extra) self.env_dict["JOB_ID"] = "$F(MY.JobId)" diff --git a/dask_jobqueue/local.py b/dask_jobqueue/local.py index 3ac2fd12..8aada893 100644 --- a/dask_jobqueue/local.py +++ b/dask_jobqueue/local.py @@ -32,7 +32,7 @@ def __init__( resource_spec=None, walltime=None, job_extra=None, - config_name="local", + config_name=None, **kwargs ): # Instantiate args and parameters from parent abstract class diff --git a/dask_jobqueue/lsf.py b/dask_jobqueue/lsf.py index 1dfda0ff..d6eb11ea 100644 --- a/dask_jobqueue/lsf.py +++ b/dask_jobqueue/lsf.py @@ -17,6 +17,7 @@ class LSFJob(Job): submit_command = "bsub" cancel_command = "bkill" + config_name = "lsf" def __init__( self, @@ -28,34 +29,33 @@ def __init__( walltime=None, job_extra=None, lsf_units=None, - config_name="lsf", + config_name=None, use_stdin=None, **kwargs ): + super().__init__(*args, config_name=config_name, **kwargs) + if queue is None: - queue = dask.config.get("jobqueue.%s.queue" % config_name) + queue = dask.config.get("jobqueue.%s.queue" % self.config_name) if project is None: - project = dask.config.get("jobqueue.%s.project" % config_name) + project = dask.config.get("jobqueue.%s.project" % self.config_name) if ncpus is None: - ncpus = dask.config.get("jobqueue.%s.ncpus" % config_name) + ncpus = dask.config.get("jobqueue.%s.ncpus" % self.config_name) if mem is None: - mem = dask.config.get("jobqueue.%s.mem" % config_name) + mem = dask.config.get("jobqueue.%s.mem" % self.config_name) if walltime is None: - walltime = dask.config.get("jobqueue.%s.walltime" % config_name) + walltime = dask.config.get("jobqueue.%s.walltime" % self.config_name) if job_extra is None: - job_extra = dask.config.get("jobqueue.%s.job-extra" % config_name) + job_extra = dask.config.get("jobqueue.%s.job-extra" % self.config_name) if lsf_units is None: - lsf_units = dask.config.get("jobqueue.%s.lsf-units" % config_name) + lsf_units = dask.config.get("jobqueue.%s.lsf-units" % self.config_name) if use_stdin is None: - use_stdin = dask.config.get("jobqueue.%s.use-stdin" % config_name) + use_stdin = dask.config.get("jobqueue.%s.use-stdin" % self.config_name) if use_stdin is None: use_stdin = lsf_version() < "10" self.use_stdin = use_stdin - # Instantiate args and parameters from parent abstract class - super().__init__(*args, config_name=config_name, **kwargs) - header_lines = [] # LSF header build if self.name is not None: diff --git a/dask_jobqueue/oar.py b/dask_jobqueue/oar.py index f59f8512..1d5defc8 100644 --- a/dask_jobqueue/oar.py +++ b/dask_jobqueue/oar.py @@ -14,6 +14,7 @@ class OARJob(Job): submit_command = "oarsub" cancel_command = "oardel" job_id_regexp = r"OAR_JOB_ID=(?P\d+)" + config_name = "oar" def __init__( self, @@ -23,21 +24,23 @@ def __init__( resource_spec=None, walltime=None, job_extra=None, - config_name="oar", + config_name=None, **kwargs ): + super().__init__(*args, config_name=config_name, **kwargs) + if queue is None: - queue = dask.config.get("jobqueue.%s.queue" % config_name) + queue = dask.config.get("jobqueue.%s.queue" % self.config_name) if project is None: - project = dask.config.get("jobqueue.%s.project" % config_name) + project = dask.config.get("jobqueue.%s.project" % self.config_name) if resource_spec is None: - resource_spec = dask.config.get("jobqueue.%s.resource-spec" % config_name) + resource_spec = dask.config.get( + "jobqueue.%s.resource-spec" % self.config_name + ) if walltime is None: - walltime = dask.config.get("jobqueue.%s.walltime" % config_name) + walltime = dask.config.get("jobqueue.%s.walltime" % self.config_name) if job_extra is None: - job_extra = dask.config.get("jobqueue.%s.job-extra" % config_name) - - super().__init__(*args, config_name=config_name, **kwargs) + job_extra = dask.config.get("jobqueue.%s.job-extra" % self.config_name) header_lines = [] if self.job_name is not None: diff --git a/dask_jobqueue/pbs.py b/dask_jobqueue/pbs.py index f8b3acd1..e819838a 100644 --- a/dask_jobqueue/pbs.py +++ b/dask_jobqueue/pbs.py @@ -47,25 +47,26 @@ def __init__( resource_spec=None, walltime=None, job_extra=None, - config_name="pbs", + config_name=None, **kwargs ): + super().__init__(*args, config_name=config_name, **kwargs) + if queue is None: - queue = dask.config.get("jobqueue.%s.queue" % config_name) + queue = dask.config.get("jobqueue.%s.queue" % self.config_name) if resource_spec is None: - resource_spec = dask.config.get("jobqueue.%s.resource-spec" % config_name) + resource_spec = dask.config.get( + "jobqueue.%s.resource-spec" % self.config_name + ) if walltime is None: - walltime = dask.config.get("jobqueue.%s.walltime" % config_name) + walltime = dask.config.get("jobqueue.%s.walltime" % self.config_name) if job_extra is None: - job_extra = dask.config.get("jobqueue.%s.job-extra" % config_name) + job_extra = dask.config.get("jobqueue.%s.job-extra" % self.config_name) if project is None: project = dask.config.get( - "jobqueue.%s.project" % config_name + "jobqueue.%s.project" % self.config_name ) or os.environ.get("PBS_ACCOUNT") - # Instantiate args and parameters from parent abstract class - super().__init__(*args, config_name=config_name, **kwargs) - # Try to find a project name from environment variable project = project or os.environ.get("PBS_ACCOUNT") diff --git a/dask_jobqueue/sge.py b/dask_jobqueue/sge.py index cc022b47..c8dc2816 100644 --- a/dask_jobqueue/sge.py +++ b/dask_jobqueue/sge.py @@ -10,6 +10,7 @@ class SGEJob(Job): submit_command = "qsub" cancel_command = "qdel" + config_name = "sge" def __init__( self, @@ -19,21 +20,23 @@ def __init__( resource_spec=None, walltime=None, job_extra=None, - config_name="sge", + config_name=None, **kwargs ): + super().__init__(*args, config_name=config_name, **kwargs) + if queue is None: - queue = dask.config.get("jobqueue.%s.queue" % config_name) + queue = dask.config.get("jobqueue.%s.queue" % self.config_name) if project is None: - project = dask.config.get("jobqueue.%s.project" % config_name) + project = dask.config.get("jobqueue.%s.project" % self.config_name) if resource_spec is None: - resource_spec = dask.config.get("jobqueue.%s.resource-spec" % config_name) + resource_spec = dask.config.get( + "jobqueue.%s.resource-spec" % self.config_name + ) if walltime is None: - walltime = dask.config.get("jobqueue.%s.walltime" % config_name) + walltime = dask.config.get("jobqueue.%s.walltime" % self.config_name) if job_extra is None: - job_extra = dask.config.get("jobqueue.%s.job-extra" % config_name) - - super().__init__(*args, config_name=config_name, **kwargs) + job_extra = dask.config.get("jobqueue.%s.job-extra" % self.config_name) header_lines = [] if self.job_name is not None: @@ -114,4 +117,3 @@ class SGECluster(JobQueueCluster): job=job_parameters, cluster=cluster_parameters ) job_cls = SGEJob - config_name = "sge" diff --git a/dask_jobqueue/slurm.py b/dask_jobqueue/slurm.py index e17c85e2..ffbf61be 100644 --- a/dask_jobqueue/slurm.py +++ b/dask_jobqueue/slurm.py @@ -12,6 +12,7 @@ class SLURMJob(Job): # Override class variables submit_command = "sbatch" cancel_command = "scancel" + config_name = "slurm" def __init__( self, @@ -22,23 +23,23 @@ def __init__( job_cpu=None, job_mem=None, job_extra=None, - config_name="slurm", + config_name=None, **kwargs ): + super().__init__(*args, config_name=config_name, **kwargs) + if queue is None: - queue = dask.config.get("jobqueue.%s.queue" % config_name) + queue = dask.config.get("jobqueue.%s.queue" % self.config_name) if project is None: - project = dask.config.get("jobqueue.%s.project" % config_name) + project = dask.config.get("jobqueue.%s.project" % self.config_name) if walltime is None: - walltime = dask.config.get("jobqueue.%s.walltime" % config_name) + walltime = dask.config.get("jobqueue.%s.walltime" % self.config_name) if job_cpu is None: - job_cpu = dask.config.get("jobqueue.%s.job-cpu" % config_name) + job_cpu = dask.config.get("jobqueue.%s.job-cpu" % self.config_name) if job_mem is None: - job_mem = dask.config.get("jobqueue.%s.job-mem" % config_name) + job_mem = dask.config.get("jobqueue.%s.job-mem" % self.config_name) if job_extra is None: - job_extra = dask.config.get("jobqueue.%s.job-extra" % config_name) - - super().__init__(*args, config_name=config_name, **kwargs) + job_extra = dask.config.get("jobqueue.%s.job-extra" % self.config_name) header_lines = [] # SLURM header build diff --git a/dask_jobqueue/tests/test_jobqueue_core.py b/dask_jobqueue/tests/test_jobqueue_core.py index bf67044d..e754d3af 100644 --- a/dask_jobqueue/tests/test_jobqueue_core.py +++ b/dask_jobqueue/tests/test_jobqueue_core.py @@ -3,9 +3,12 @@ import socket import sys import re +import psutil import pytest +import dask + from dask_jobqueue import ( JobQueueCluster, PBSCluster, @@ -15,12 +18,15 @@ LSFCluster, OARCluster, ) +from dask_jobqueue.core import Job +from dask_jobqueue.local import LocalCluster from dask_jobqueue.sge import SGEJob def test_errors(): - with pytest.raises(ValueError, match="Job type.*job_cls="): + match = re.compile("Job type.*job_cls", flags=re.DOTALL) + with pytest.raises(ValueError, match=match): JobQueueCluster(cores=4) @@ -207,3 +213,44 @@ def test_cluster_has_cores_and_memory(Cluster): with pytest.raises(ValueError, match=base_regex + r"cores=4, memory='\d+GB'"): Cluster(cores=4) + + +@pytest.mark.asyncio +async def test_config_interface(): + net_if_addrs = psutil.net_if_addrs() + interface = list(net_if_addrs.keys())[0] + with dask.config.set({"jobqueue.local.interface": interface}): + cluster = LocalCluster(cores=1, memory="2GB", asynchronous=True) + await cluster + expected = "'interface': {!r}".format(interface) + assert expected in str(cluster.scheduler_spec) + cluster.scale(1) + assert expected in str(cluster.worker_spec) + + +# TODO where to put these tests +def test_job_without_config_name(): + class MyJob(Job): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + with pytest.raises(ValueError, match="config_name.+MyJob"): + MyJob(cores=1, memory="1GB") + + class MyJobWithNoneConfigName(MyJob): + config_name = None + + with pytest.raises(ValueError, match="config_name.+MyJobWithNoneConfigName"): + MyJobWithNoneConfigName(cores=1, memory="1GB") + + with pytest.raises(ValueError, match="config_name.+MyJobWithNoneConfigName"): + JobQueueCluster(job_cls=MyJobWithNoneConfigName, cores=1, memory="1GB") + + +def test_cluster_without_job_cls(): + class MyCluster(JobQueueCluster): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + with pytest.raises(ValueError, match="job_cls.+MyCluster"): + MyCluster(cores=1, memory="1GB")