Skip to content

Commit 75524e2

Browse files
bors[bot]cuviperemilio
authored
Merge #1063
1063: core: Introduce ThreadPoolBuilder::use_current_thread. r=cuviper a=emilio This generalizes the approach used by targets that don't support threading like wasm, allowing the builder thread to be part of a new thread-pool. This PR: * Builds on top of the PoC implementation from that issue. * Renames the API as per the comments there. * Adds a way to clean up the WorkerThread storage once the pool is dropped. * Documents and tests the APIs. Feedback welcome. `clean_up_use_current_thread` is not a great name, but I think it's descriptive, and maybe good enough given it's a rather niche API for non-global pools? Co-authored-by: Josh Stone <cuviper@gmail.com> Co-authored-by: Emilio Cobos Álvarez <emilio@crisal.io>
2 parents dc7090a + 01d2800 commit 75524e2

File tree

4 files changed

+113
-21
lines changed

4 files changed

+113
-21
lines changed

rayon-core/Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,3 +53,7 @@ path = "tests/simple_panic.rs"
5353
[[test]]
5454
name = "scoped_threadpool"
5555
path = "tests/scoped_threadpool.rs"
56+
57+
[[test]]
58+
name = "use_current_thread"
59+
path = "tests/use_current_thread.rs"

rayon-core/src/lib.rs

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ pub struct ThreadPoolBuildError {
147147
#[derive(Debug)]
148148
enum ErrorKind {
149149
GlobalPoolAlreadyInitialized,
150+
CurrentThreadAlreadyInPool,
150151
IOError(io::Error),
151152
}
152153

@@ -174,6 +175,9 @@ pub struct ThreadPoolBuilder<S = DefaultSpawn> {
174175
/// If RAYON_NUM_THREADS is invalid or zero will use the default.
175176
num_threads: usize,
176177

178+
/// The thread we're building *from* will also be part of the pool.
179+
use_current_thread: bool,
180+
177181
/// Custom closure, if any, to handle a panic that we cannot propagate
178182
/// anywhere else.
179183
panic_handler: Option<Box<PanicHandler>>,
@@ -227,6 +231,7 @@ impl Default for ThreadPoolBuilder {
227231
fn default() -> Self {
228232
ThreadPoolBuilder {
229233
num_threads: 0,
234+
use_current_thread: false,
230235
panic_handler: None,
231236
get_thread_name: None,
232237
stack_size: None,
@@ -437,6 +442,7 @@ impl<S> ThreadPoolBuilder<S> {
437442
spawn_handler: CustomSpawn::new(spawn),
438443
// ..self
439444
num_threads: self.num_threads,
445+
use_current_thread: self.use_current_thread,
440446
panic_handler: self.panic_handler,
441447
get_thread_name: self.get_thread_name,
442448
stack_size: self.stack_size,
@@ -529,6 +535,24 @@ impl<S> ThreadPoolBuilder<S> {
529535
self
530536
}
531537

538+
/// Use the current thread as one of the threads in the pool.
539+
///
540+
/// The current thread is guaranteed to be at index 0, and since the thread is not managed by
541+
/// rayon, the spawn and exit handlers do not run for that thread.
542+
///
543+
/// Note that the current thread won't run the main work-stealing loop, so jobs spawned into
544+
/// the thread-pool will generally not be picked up automatically by this thread unless you
545+
/// yield to rayon in some way, like via [`yield_now()`], [`yield_local()`], or [`scope()`].
546+
///
547+
/// # Local thread-pools
548+
///
549+
/// Using this in a local thread-pool means the registry will be leaked. In future versions
550+
/// there might be a way of cleaning up the current-thread state.
551+
pub fn use_current_thread(mut self) -> Self {
552+
self.use_current_thread = true;
553+
self
554+
}
555+
532556
/// Returns a copy of the current panic handler.
533557
fn take_panic_handler(&mut self) -> Option<Box<PanicHandler>> {
534558
self.panic_handler.take()
@@ -731,18 +755,22 @@ impl ThreadPoolBuildError {
731755
const GLOBAL_POOL_ALREADY_INITIALIZED: &str =
732756
"The global thread pool has already been initialized.";
733757

758+
const CURRENT_THREAD_ALREADY_IN_POOL: &str =
759+
"The current thread is already part of another thread pool.";
760+
734761
impl Error for ThreadPoolBuildError {
735762
#[allow(deprecated)]
736763
fn description(&self) -> &str {
737764
match self.kind {
738765
ErrorKind::GlobalPoolAlreadyInitialized => GLOBAL_POOL_ALREADY_INITIALIZED,
766+
ErrorKind::CurrentThreadAlreadyInPool => CURRENT_THREAD_ALREADY_IN_POOL,
739767
ErrorKind::IOError(ref e) => e.description(),
740768
}
741769
}
742770

743771
fn source(&self) -> Option<&(dyn Error + 'static)> {
744772
match &self.kind {
745-
ErrorKind::GlobalPoolAlreadyInitialized => None,
773+
ErrorKind::GlobalPoolAlreadyInitialized | ErrorKind::CurrentThreadAlreadyInPool => None,
746774
ErrorKind::IOError(e) => Some(e),
747775
}
748776
}
@@ -751,6 +779,7 @@ impl Error for ThreadPoolBuildError {
751779
impl fmt::Display for ThreadPoolBuildError {
752780
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
753781
match &self.kind {
782+
ErrorKind::CurrentThreadAlreadyInPool => CURRENT_THREAD_ALREADY_IN_POOL.fmt(f),
754783
ErrorKind::GlobalPoolAlreadyInitialized => GLOBAL_POOL_ALREADY_INITIALIZED.fmt(f),
755784
ErrorKind::IOError(e) => e.fmt(f),
756785
}
@@ -768,6 +797,7 @@ impl<S> fmt::Debug for ThreadPoolBuilder<S> {
768797
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
769798
let ThreadPoolBuilder {
770799
ref num_threads,
800+
ref use_current_thread,
771801
ref get_thread_name,
772802
ref panic_handler,
773803
ref stack_size,
@@ -792,6 +822,7 @@ impl<S> fmt::Debug for ThreadPoolBuilder<S> {
792822

793823
f.debug_struct("ThreadPoolBuilder")
794824
.field("num_threads", num_threads)
825+
.field("use_current_thread", use_current_thread)
795826
.field("get_thread_name", &get_thread_name)
796827
.field("panic_handler", &panic_handler)
797828
.field("stack_size", &stack_size)

rayon-core/src/registry.rs

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -207,26 +207,7 @@ fn default_global_registry() -> Result<Arc<Registry>, ThreadPoolBuildError> {
207207
// is stubbed out, and we won't have to change anything if they do add real threading.
208208
let unsupported = matches!(&result, Err(e) if e.is_unsupported());
209209
if unsupported && WorkerThread::current().is_null() {
210-
let builder = ThreadPoolBuilder::new()
211-
.num_threads(1)
212-
.spawn_handler(|thread| {
213-
// Rather than starting a new thread, we're just taking over the current thread
214-
// *without* running the main loop, so we can still return from here.
215-
// The WorkerThread is leaked, but we never shutdown the global pool anyway.
216-
let worker_thread = Box::leak(Box::new(WorkerThread::from(thread)));
217-
let registry = &*worker_thread.registry;
218-
let index = worker_thread.index;
219-
220-
unsafe {
221-
WorkerThread::set_current(worker_thread);
222-
223-
// let registry know we are ready to do work
224-
Latch::set(&registry.thread_infos[index].primed);
225-
}
226-
227-
Ok(())
228-
});
229-
210+
let builder = ThreadPoolBuilder::new().num_threads(1).use_current_thread();
230211
let fallback_result = Registry::new(builder);
231212
if fallback_result.is_ok() {
232213
return fallback_result;
@@ -300,6 +281,25 @@ impl Registry {
300281
stealer,
301282
index,
302283
};
284+
285+
if index == 0 && builder.use_current_thread {
286+
if !WorkerThread::current().is_null() {
287+
return Err(ThreadPoolBuildError::new(
288+
ErrorKind::CurrentThreadAlreadyInPool,
289+
));
290+
}
291+
// Rather than starting a new thread, we're just taking over the current thread
292+
// *without* running the main loop, so we can still return from here.
293+
// The WorkerThread is leaked, but we never shutdown the global pool anyway.
294+
let worker_thread = Box::into_raw(Box::new(WorkerThread::from(thread)));
295+
296+
unsafe {
297+
WorkerThread::set_current(worker_thread);
298+
Latch::set(&registry.thread_infos[index].primed);
299+
}
300+
continue;
301+
}
302+
303303
if let Err(e) = builder.get_spawn_handler().spawn(thread) {
304304
return Err(ThreadPoolBuildError::new(ErrorKind::IOError(e)));
305305
}
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
use rayon_core::ThreadPoolBuilder;
2+
use std::sync::{Arc, Condvar, Mutex};
3+
use std::thread::{self, JoinHandle};
4+
5+
#[test]
6+
#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
7+
fn use_current_thread_basic() {
8+
static JOIN_HANDLES: Mutex<Vec<JoinHandle<()>>> = Mutex::new(Vec::new());
9+
let pool = ThreadPoolBuilder::new()
10+
.num_threads(2)
11+
.use_current_thread()
12+
.spawn_handler(|builder| {
13+
let handle = thread::Builder::new().spawn(|| builder.run())?;
14+
JOIN_HANDLES.lock().unwrap().push(handle);
15+
Ok(())
16+
})
17+
.build()
18+
.unwrap();
19+
assert_eq!(rayon_core::current_thread_index(), Some(0));
20+
assert_eq!(
21+
JOIN_HANDLES.lock().unwrap().len(),
22+
1,
23+
"Should only spawn one extra thread"
24+
);
25+
26+
let another_pool = ThreadPoolBuilder::new()
27+
.num_threads(2)
28+
.use_current_thread()
29+
.build();
30+
assert!(
31+
another_pool.is_err(),
32+
"Should error if the thread is already part of a pool"
33+
);
34+
35+
let pair = Arc::new((Mutex::new(false), Condvar::new()));
36+
let pair2 = Arc::clone(&pair);
37+
pool.spawn(move || {
38+
assert_ne!(rayon_core::current_thread_index(), Some(0));
39+
// This should execute even if the current thread is blocked, since we have two threads in
40+
// the pool.
41+
let &(ref started, ref condvar) = &*pair2;
42+
*started.lock().unwrap() = true;
43+
condvar.notify_one();
44+
});
45+
46+
let _guard = pair
47+
.1
48+
.wait_while(pair.0.lock().unwrap(), |ran| !*ran)
49+
.unwrap();
50+
std::mem::drop(pool); // Drop the pool.
51+
52+
// Wait until all threads have actually exited. This is not really needed, other than to
53+
// reduce noise of leak-checking tools.
54+
for handle in std::mem::take(&mut *JOIN_HANDLES.lock().unwrap()) {
55+
let _ = handle.join();
56+
}
57+
}

0 commit comments

Comments
 (0)