Skip to content

Commit 96c9088

Browse files
committed
IO trampoline async (arrow-kt#1448)
1 parent be8deb0 commit 96c9088

File tree

4 files changed

+68
-18
lines changed

4 files changed

+68
-18
lines changed

modules/effects/arrow-effects-data/src/main/kotlin/arrow/effects/IO.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ sealed class IO<out A> : IOOf<A> {
202202
override fun unsafeRunTimedTotal(limit: Duration): Option<A> = throw AssertionError("Unreachable")
203203
}
204204

205-
internal data class Async<out A>(val k: IOProc<A>) : IO<A>() {
205+
internal data class Async<out A>(val shouldTrampoline: Boolean = false, val k: IOProc<A>) : IO<A>() {
206206
override fun unsafeRunTimedTotal(limit: Duration): Option<A> = unsafeResync(this, limit)
207207
}
208208

modules/effects/arrow-effects-data/src/main/kotlin/arrow/effects/IORunLoop.kt

Lines changed: 47 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import arrow.core.NonFatal
66
import arrow.core.Right
77
import arrow.core.nonFatalOrThrow
88
import arrow.effects.internal.ArrowInternalException
9+
import arrow.effects.internal.Platform
910
import arrow.effects.internal.Platform.ArrayStack
1011
import kotlin.coroutines.CoroutineContext
1112
import kotlin.coroutines.EmptyCoroutineContext
@@ -64,7 +65,7 @@ internal object IORunLoop {
6465
hasResult = true
6566
currentIO = null
6667
} catch (t: Throwable) {
67-
currentIO = IO.RaiseError(t.nonFatalOrThrow())
68+
currentIO = IO.RaiseError(t.nonFatalOrThrow())
6869
}
6970
}
7071
is IO.Async -> {
@@ -371,6 +372,12 @@ internal object IORunLoop {
371372
private var bFirst: BindF? = null
372373
private var bRest: CallStack? = null
373374

375+
private var contIndex: Int = 0
376+
private var trampolineAfter: Boolean = false
377+
private inline val shouldTrampoline inline get() = trampolineAfter || contIndex == Platform.maxStackDepthSize
378+
379+
private var value: IO<Any?>? = null
380+
374381
fun contextSwitch(conn: IOConnection) {
375382
this.conn = conn
376383
}
@@ -379,10 +386,12 @@ internal object IORunLoop {
379386
canCall = true
380387
this.bFirst = bFirst
381388
this.bRest = bRest
389+
contIndex++
382390
}
383391

384392
fun start(async: IO.Async<Any?>, bFirst: BindF?, bRest: CallStack?) {
385393
prepare(bFirst, bRest)
394+
trampolineAfter = async.shouldTrampoline
386395
async.k(conn, this)
387396
}
388397

@@ -392,12 +401,30 @@ internal object IORunLoop {
392401
effect.effect.startCoroutine(this)
393402
}
394403

404+
private fun signal(result: IO<Any?>) {
405+
// Allow GC to collect
406+
val bFirst = this.bFirst
407+
val bRest = this.bRest
408+
this.bFirst = null
409+
this.bRest = null
410+
this._context = EmptyCoroutineContext
411+
412+
loop(result, conn, cb, this, bFirst, bRest)
413+
}
414+
395415
override operator fun invoke(either: Either<Throwable, Any?>) {
396416
if (canCall) {
397417
canCall = false
398418
when (either) {
399-
is Either.Left -> loop(IO.RaiseError(either.a), conn, cb, this, bFirst, bRest)
400-
is Either.Right -> loop(IO.Pure(either.b), conn, cb, this, bFirst, bRest)
419+
is Either.Left -> IO.RaiseError(either.a)
420+
is Either.Right -> IO.Pure(either.b)
421+
}.let { r ->
422+
if (shouldTrampoline) {
423+
this.value = r
424+
Platform.trampoline { trampoline() }
425+
} else {
426+
signal(r)
427+
}
401428
}
402429
}
403430
}
@@ -406,11 +433,25 @@ internal object IORunLoop {
406433
if (canCall) {
407434
canCall = false
408435
result.fold(
409-
{ a -> loop(IO.Pure(a), conn, cb, this, bFirst, bRest) },
410-
{ e -> loop(IO.RaiseError(e), conn, cb, this, bFirst, bRest) }
411-
)
436+
{ a -> IO.Pure(a) },
437+
{ e -> IO.RaiseError(e) }
438+
).let { r ->
439+
if (shouldTrampoline) {
440+
this.value = r
441+
Platform.trampoline { trampoline() }
442+
} else {
443+
signal(r)
444+
}
445+
}
412446
}
413447
}
448+
449+
fun trampoline() {
450+
val v = value
451+
value = null
452+
contIndex = 0
453+
signal(v!!)
454+
}
414455
}
415456

416457
private class RestoreContext(

modules/effects/arrow-effects-data/src/main/kotlin/arrow/effects/internal/Utils.kt

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -162,9 +162,7 @@ object Platform {
162162
latch.tryAcquireSharedNanos(1, limit.nanoseconds)
163163
}
164164

165-
val eitherRef = ref
166-
167-
return when (eitherRef) {
165+
return when (val eitherRef = ref) {
168166
null -> None
169167
is Either.Left -> throw eitherRef.a
170168
is Either.Right -> Some(eitherRef.b)
@@ -207,7 +205,7 @@ object Platform {
207205

208206
@PublishedApi
209207
internal class TrampolineExecutor(val underlying: Executor) {
210-
private var immediateQueue = Platform.ArrayStack<Runnable>()
208+
private var immediateQueue = ArrayStack<Runnable>()
211209
@Volatile
212210
private var withinLoop = false
213211

@@ -225,7 +223,7 @@ object Platform {
225223
else immediateQueue.push(runnable)
226224

227225
private fun forkTheRest() {
228-
class ResumeRun(val head: Runnable, val rest: Platform.ArrayStack<Runnable>) : Runnable {
226+
class ResumeRun(val head: Runnable, val rest: ArrayStack<Runnable>) : Runnable {
229227
override fun run() {
230228
immediateQueue.pushAll(rest)
231229
immediateLoop(head)
@@ -235,7 +233,7 @@ object Platform {
235233
val head = immediateQueue.pop()
236234
if (head != null) {
237235
val rest = immediateQueue
238-
immediateQueue = Platform.ArrayStack()
236+
immediateQueue = ArrayStack()
239237
underlying.execute(ResumeRun(head, rest))
240238
}
241239
}

modules/effects/arrow-effects-data/src/test/kotlin/arrow/effects/IOTest.kt

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@ import arrow.core.None
66
import arrow.core.Right
77
import arrow.core.Some
88
import arrow.core.Tuple3
9-
import arrow.core.fix
10-
import arrow.core.flatMap
119
import arrow.core.right
1210
import arrow.effects.IO.Companion.just
1311
import arrow.effects.extensions.io.async.async
@@ -284,7 +282,7 @@ class IOTest : UnitSpec() {
284282

285283
val result =
286284
newSingleThreadContext("all").parMapN(
287-
makePar(6), IO.just(1L).order(), makePar(4), IO.defer { IO.just(2L) }.order(), makePar(5), IO { 3L }.order()) { six, one, four, two, five, three -> listOf(six, one, four, two, five, three) }
285+
makePar(6), just(1L).order(), makePar(4), IO.defer { just(2L) }.order(), makePar(5), IO { 3L }.order()) { six, one, four, two, five, three -> listOf(six, one, four, two, five, three) }
288286
.unsafeRunSync()
289287

290288
result shouldBe listOf(6L, 1, 4, 2, 5, 3)
@@ -301,7 +299,7 @@ class IOTest : UnitSpec() {
301299

302300
val result =
303301
newSingleThreadContext("all").parMapN(
304-
makePar(6), IO.just(1L), makePar(4), IO.defer { IO.just(2L) }, makePar(5), IO { 3L }) { _, _, _, _, _, _ ->
302+
makePar(6), just(1L), makePar(4), IO.defer { just(2L) }, makePar(5), IO { 3L }) { _, _, _, _, _, _ ->
305303
Thread.currentThread().name
306304
}.unsafeRunSync()
307305

@@ -466,7 +464,20 @@ class IOTest : UnitSpec() {
466464
else just(ii)
467465
}
468466

469-
IO.just(1).flatMap { ioGuaranteeCase(0) }.unsafeRunSync() shouldBe size
467+
just(1).flatMap { ioGuaranteeCase(0) }.unsafeRunSync() shouldBe size
468+
}
469+
470+
"Async should be stack safe" {
471+
val size = 5000
472+
473+
fun ioAsync(i: Int): IO<Int> = IO.async<Int> { _, cb ->
474+
cb(Right(i))
475+
}.flatMap { ii ->
476+
if (ii < size) ioAsync(ii + 1)
477+
else just(ii)
478+
}
479+
480+
IO.just(1).flatMap(::ioAsync).unsafeRunSync() shouldBe size
470481
}
471482
}
472483
}

0 commit comments

Comments
 (0)