Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 53 additions & 90 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,9 @@ class DAGScheduler(
private val nextStageId = new AtomicInteger(0)

private[scheduler] val jobIdToStageIds = new HashMap[Int, HashSet[Int]]
private[scheduler] val stageIdToJobIds = new HashMap[Int, HashSet[Int]]
private[scheduler] val stageIdToStage = new HashMap[Int, Stage]
private[scheduler] val shuffleToMapStage = new HashMap[Int, Stage]
private[scheduler] val jobIdToActiveJob = new HashMap[Int, ActiveJob]
private[scheduler] val resultStageToJob = new HashMap[Stage, ActiveJob]
private[scheduler] val stageToInfos = new HashMap[Stage, StageInfo]

// Stages we need to run whose parents aren't done
private[scheduler] val waitingStages = new HashSet[Stage]
Expand All @@ -101,9 +98,6 @@ class DAGScheduler(
// Stages that must be resubmitted due to fetch failures
private[scheduler] val failedStages = new HashSet[Stage]

// Missing tasks from each stage
private[scheduler] val pendingTasks = new HashMap[Stage, HashSet[Task[_]]]

private[scheduler] val activeJobs = new HashSet[ActiveJob]

// Contains the locations that each RDD's partitions are cached on
Expand Down Expand Up @@ -223,7 +217,6 @@ class DAGScheduler(
new Stage(id, rdd, numTasks, shuffleDep, getParentStages(rdd, jobId), jobId, callSite)
stageIdToStage(id) = stage
updateJobIdStageIdMaps(jobId, stage)
stageToInfos(stage) = StageInfo.fromStage(stage)
stage
}

Expand Down Expand Up @@ -315,13 +308,12 @@ class DAGScheduler(
*/
private def updateJobIdStageIdMaps(jobId: Int, stage: Stage) {
def updateJobIdStageIdMapsList(stages: List[Stage]) {
if (!stages.isEmpty) {
if (stages.nonEmpty) {
val s = stages.head
stageIdToJobIds.getOrElseUpdate(s.id, new HashSet[Int]()) += jobId
s.jobIds += jobId
jobIdToStageIds.getOrElseUpdate(jobId, new HashSet[Int]()) += s.id
val parents = getParentStages(s.rdd, jobId)
val parentsWithoutThisJobId = parents.filter(p =>
!stageIdToJobIds.get(p.id).exists(_.contains(jobId)))
val parents: List[Stage] = getParentStages(s.rdd, jobId)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why the type annotation?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is not obvious to me what `getParentStages` returns just from staring at the code. Sometimes when we refer to stages we use a set of integers (stage IDs), and sometimes we use the Stage objects themselves.

val parentsWithoutThisJobId = parents.filter { ! _.jobIds.contains(jobId) }
updateJobIdStageIdMapsList(parentsWithoutThisJobId ++ stages.tail)
}
}
Expand All @@ -333,16 +325,15 @@ class DAGScheduler(
* handle cancelling tasks or notifying the SparkListener about finished jobs/stages/tasks.
*
* @param job The job whose state to cleanup.
* @param resultStage Specifies the result stage for the job; if set to None, this method
* searches resultStagesToJob to find and cleanup the appropriate result stage.
*/
private def cleanupStateForJobAndIndependentStages(job: ActiveJob, resultStage: Option[Stage]) {
private def cleanupStateForJobAndIndependentStages(job: ActiveJob) {
val registeredStages = jobIdToStageIds.get(job.jobId)
if (registeredStages.isEmpty || registeredStages.get.isEmpty) {
logError("No stages registered for job " + job.jobId)
} else {
stageIdToJobIds.filterKeys(stageId => registeredStages.get.contains(stageId)).foreach {
case (stageId, jobSet) =>
stageIdToStage.filterKeys(stageId => registeredStages.get.contains(stageId)).foreach {
case (stageId, stage) =>
val jobSet = stage.jobIds
if (!jobSet.contains(job.jobId)) {
logError(
"Job %d not registered for stage %d even though that stage was registered for the job"
Expand All @@ -355,14 +346,9 @@ class DAGScheduler(
logDebug("Removing running stage %d".format(stageId))
runningStages -= stage
}
stageToInfos -= stage
for ((k, v) <- shuffleToMapStage.find(_._2 == stage)) {
shuffleToMapStage.remove(k)
}
if (pendingTasks.contains(stage) && !pendingTasks(stage).isEmpty) {
logDebug("Removing pending status for stage %d".format(stageId))
}
pendingTasks -= stage
if (waitingStages.contains(stage)) {
logDebug("Removing stage %d from waiting set.".format(stageId))
waitingStages -= stage
Expand All @@ -374,7 +360,6 @@ class DAGScheduler(
}
// data structures based on StageId
stageIdToStage -= stageId
stageIdToJobIds -= stageId

ShuffleMapTask.removeStage(stageId)
ResultTask.removeStage(stageId)
Expand All @@ -393,19 +378,7 @@ class DAGScheduler(
jobIdToStageIds -= job.jobId
jobIdToActiveJob -= job.jobId
activeJobs -= job

if (resultStage.isEmpty) {
// Clean up result stages.
val resultStagesForJob = resultStageToJob.keySet.filter(
stage => resultStageToJob(stage).jobId == job.jobId)
if (resultStagesForJob.size != 1) {
logWarning(
s"${resultStagesForJob.size} result stages for job ${job.jobId} (expect exactly 1)")
}
resultStageToJob --= resultStagesForJob
} else {
resultStageToJob -= resultStage.get
}
job.finalStage.resultOfJob = None
}

/**
Expand Down Expand Up @@ -591,9 +564,10 @@ class DAGScheduler(
job.listener.jobFailed(exception)
} finally {
val s = job.finalStage
stageIdToJobIds -= s.id // clean up data structures that were populated for a local job,
stageIdToStage -= s.id // but that won't get cleaned up via the normal paths through
stageToInfos -= s // completion events or stage abort
// clean up data structures that were populated for a local job,
// but that won't get cleaned up via the normal paths through
// completion events or stage abort
stageIdToStage -= s.id
jobIdToStageIds -= job.jobId
listenerBus.post(SparkListenerJobEnd(job.jobId, jobResult))
}
Expand All @@ -605,12 +579,8 @@ class DAGScheduler(
// That should take care of at least part of the priority inversion problem with
// cross-job dependencies.
private def activeJobForStage(stage: Stage): Option[Int] = {
if (stageIdToJobIds.contains(stage.id)) {
val jobsThatUseStage: Array[Int] = stageIdToJobIds(stage.id).toArray.sorted
jobsThatUseStage.find(jobIdToActiveJob.contains)
} else {
None
}
val jobsThatUseStage: Array[Int] = stage.jobIds.toArray.sorted
jobsThatUseStage.find(jobIdToActiveJob.contains)
}

private[scheduler] def handleJobGroupCancelled(groupId: String) {
Expand Down Expand Up @@ -642,9 +612,8 @@ class DAGScheduler(
// is in the process of getting stopped.
val stageFailedMessage = "Stage cancelled because SparkContext was shut down"
runningStages.foreach { stage =>
val info = stageToInfos(stage)
info.stageFailed(stageFailedMessage)
listenerBus.post(SparkListenerStageCompleted(info))
stage.info.stageFailed(stageFailedMessage)
listenerBus.post(SparkListenerStageCompleted(stage.info))
}
listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error)))
}
Expand Down Expand Up @@ -690,7 +659,7 @@ class DAGScheduler(
} else {
jobIdToActiveJob(jobId) = job
activeJobs += job
resultStageToJob(finalStage) = job
finalStage.resultOfJob = Some(job)
listenerBus.post(SparkListenerJobStart(job.jobId, jobIdToStageIds(jobId).toArray,
properties))
submitStage(finalStage)
Expand Down Expand Up @@ -727,8 +696,7 @@ class DAGScheduler(
private def submitMissingTasks(stage: Stage, jobId: Int) {
logDebug("submitMissingTasks(" + stage + ")")
// Get our pending tasks and remember them in our pendingTasks entry
val myPending = pendingTasks.getOrElseUpdate(stage, new HashSet)
myPending.clear()
stage.pendingTasks.clear()
var tasks = ArrayBuffer[Task[_]]()
if (stage.isShuffleMap) {
for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) {
Expand All @@ -737,7 +705,7 @@ class DAGScheduler(
}
} else {
// This is a final stage; figure out its job's missing partitions
val job = resultStageToJob(stage)
val job = stage.resultOfJob.get
for (id <- 0 until job.numPartitions if !job.finished(id)) {
val partition = job.partitions(id)
val locs = getPreferredLocs(stage.rdd, partition)
Expand All @@ -758,7 +726,7 @@ class DAGScheduler(
// serializable. If tasks are not serializable, a SparkListenerStageCompleted event
// will be posted, which should always come after a corresponding SparkListenerStageSubmitted
// event.
listenerBus.post(SparkListenerStageSubmitted(stageToInfos(stage), properties))
listenerBus.post(SparkListenerStageSubmitted(stage.info, properties))

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do I need to call `stage.pendingTasks = new HashSet` here? @kayousterhout @markhamstra

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, you should clear it here; I think the issue arises if a stage gets resubmitted.

// Preemptively serialize a task to make sure it can be serialized. We are catching this
// exception here because it would be fairly hard to catch the non-serializable exception
Expand All @@ -778,11 +746,11 @@ class DAGScheduler(
}

logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")")
myPending ++= tasks
logDebug("New pending tasks: " + myPending)
stage.pendingTasks ++= tasks
logDebug("New pending tasks: " + stage.pendingTasks)
taskScheduler.submitTasks(
new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties))
stageToInfos(stage).submissionTime = Some(clock.getTime())
stage.info.submissionTime = Some(clock.getTime())
} else {
logDebug("Stage " + stage + " is actually done; %b %d %d".format(
stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions))
Expand All @@ -807,13 +775,13 @@ class DAGScheduler(
val stage = stageIdToStage(task.stageId)

def markStageAsFinished(stage: Stage) = {
val serviceTime = stageToInfos(stage).submissionTime match {
val serviceTime = stage.info.submissionTime match {
case Some(t) => "%.03f".format((clock.getTime() - t) / 1000.0)
case _ => "Unknown"
}
logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime))
stageToInfos(stage).completionTime = Some(clock.getTime())
listenerBus.post(SparkListenerStageCompleted(stageToInfos(stage)))
stage.info.completionTime = Some(clock.getTime())
listenerBus.post(SparkListenerStageCompleted(stage.info))
runningStages -= stage
}
event.reason match {
Expand All @@ -822,18 +790,18 @@ class DAGScheduler(
// TODO: fail the stage if the accumulator update fails...
Accumulators.add(event.accumUpdates) // TODO: do this only if task wasn't resubmitted
}
pendingTasks(stage) -= task
stage.pendingTasks -= task
task match {
case rt: ResultTask[_, _] =>
resultStageToJob.get(stage) match {
stage.resultOfJob match {
case Some(job) =>
if (!job.finished(rt.outputId)) {
job.finished(rt.outputId) = true
job.numFinished += 1
// If the whole job has finished, remove it
if (job.numFinished == job.numPartitions) {
markStageAsFinished(stage)
cleanupStateForJobAndIndependentStages(job, Some(stage))
cleanupStateForJobAndIndependentStages(job)
listenerBus.post(SparkListenerJobEnd(job.jobId, JobSucceeded))
}

Expand All @@ -860,7 +828,7 @@ class DAGScheduler(
} else {
stage.addOutputLoc(smt.partitionId, status)
}
if (runningStages.contains(stage) && pendingTasks(stage).isEmpty) {
if (runningStages.contains(stage) && stage.pendingTasks.isEmpty) {
markStageAsFinished(stage)
logInfo("looking for newly runnable stages")
logInfo("running: " + runningStages)
Expand Down Expand Up @@ -909,7 +877,7 @@ class DAGScheduler(

case Resubmitted =>
logInfo("Resubmitted " + task + ", so marking it as still running")
pendingTasks(stage) += task
stage.pendingTasks += task

case FetchFailed(bmAddress, shuffleId, mapId, reduceId) =>
// Mark the stage that the reducer was in as unrunnable
Expand Down Expand Up @@ -994,13 +962,14 @@ class DAGScheduler(
}

private[scheduler] def handleStageCancellation(stageId: Int) {
if (stageIdToJobIds.contains(stageId)) {
val jobsThatUseStage: Array[Int] = stageIdToJobIds(stageId).toArray
jobsThatUseStage.foreach(jobId => {
handleJobCancellation(jobId, "because Stage %s was cancelled".format(stageId))
})
} else {
logInfo("No active jobs to kill for Stage " + stageId)
stageIdToStage.get(stageId) match {
case Some(stage) =>
val jobsThatUseStage: Array[Int] = stage.jobIds.toArray
jobsThatUseStage.foreach { jobId =>
handleJobCancellation(jobId, s"because Stage $stageId was cancelled")
}
case None =>
logInfo("No active jobs to kill for Stage " + stageId)
}
submitWaitingStages()
}
Expand All @@ -1009,8 +978,8 @@ class DAGScheduler(
if (!jobIdToStageIds.contains(jobId)) {
logDebug("Trying to cancel unregistered job " + jobId)
} else {
failJobAndIndependentStages(jobIdToActiveJob(jobId),
"Job %d cancelled %s".format(jobId, reason), None)
failJobAndIndependentStages(
jobIdToActiveJob(jobId), "Job %d cancelled %s".format(jobId, reason))
}
submitWaitingStages()
}
Expand All @@ -1024,26 +993,21 @@ class DAGScheduler(
// Skip all the actions if the stage has been removed.
return
}
val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq
stageToInfos(failedStage).completionTime = Some(clock.getTime())
for (resultStage <- dependentStages) {
val job = resultStageToJob(resultStage)
failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason",
Some(resultStage))
val dependentJobs: Seq[ActiveJob] =
activeJobs.filter(job => stageDependsOn(job.finalStage, failedStage)).toSeq
failedStage.info.completionTime = Some(clock.getTime())
for (job <- dependentJobs) {
failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason")
}
if (dependentStages.isEmpty) {
if (dependentJobs.isEmpty) {
logInfo("Ignoring failure of " + failedStage + " because all jobs depending on it are done")
}
}

/**
* Fails a job and all stages that are only used by that job, and cleans up relevant state.
*
* @param resultStage The result stage for the job, if known. Used to cleanup state for the job
* slightly more efficiently than when not specified.
*/
private def failJobAndIndependentStages(job: ActiveJob, failureReason: String,
resultStage: Option[Stage]) {
private def failJobAndIndependentStages(job: ActiveJob, failureReason: String) {
val error = new SparkException(failureReason)
var ableToCancelStages = true

Expand All @@ -1057,7 +1021,7 @@ class DAGScheduler(
logError("No stages registered for job " + job.jobId)
}
stages.foreach { stageId =>
val jobsForStage = stageIdToJobIds.get(stageId)
val jobsForStage: Option[HashSet[Int]] = stageIdToStage.get(stageId).map(_.jobIds)
if (jobsForStage.isEmpty || !jobsForStage.get.contains(job.jobId)) {
logError(
"Job %d not registered for stage %d even though that stage was registered for the job"
Expand All @@ -1071,9 +1035,8 @@ class DAGScheduler(
if (runningStages.contains(stage)) {
try { // cancelTasks will fail if a SchedulerBackend does not implement killTask
taskScheduler.cancelTasks(stageId, shouldInterruptThread)
val stageInfo = stageToInfos(stage)
stageInfo.stageFailed(failureReason)
listenerBus.post(SparkListenerStageCompleted(stageToInfos(stage)))
stage.info.stageFailed(failureReason)
listenerBus.post(SparkListenerStageCompleted(stage.info))
} catch {
case e: UnsupportedOperationException =>
logInfo(s"Could not cancel tasks for stage $stageId", e)
Expand All @@ -1086,7 +1049,7 @@ class DAGScheduler(

if (ableToCancelStages) {
job.listener.jobFailed(error)
cleanupStateForJobAndIndependentStages(job, resultStage)
cleanupStateForJobAndIndependentStages(job)
listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error)))
}
}
Expand Down
Loading