Skip to content

Commit 19d112c

Browse files
authored
IO rewrite ContinueOn impl (#1443)
1 parent 2f540d6 commit 19d112c

File tree

2 files changed

+58
-43
lines changed

2 files changed

+58
-43
lines changed

modules/effects/arrow-effects-data/src/main/kotlin/arrow/effects/IO.kt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,10 @@ sealed class IO<out A> : IOOf<A> {
206206
override fun unsafeRunTimedTotal(limit: Duration): Option<A> = unsafeResync(this, limit)
207207
}
208208

209+
internal data class Effect<out A>(val ctx: CoroutineContext, val effect: suspend () -> A) : IO<A>() {
210+
override fun unsafeRunTimedTotal(limit: Duration): Option<A> = unsafeResync(this, limit)
211+
}
212+
209213
internal data class Bind<E, out A>(val cont: IO<E>, val g: (E) -> IO<A>) : IO<A>() {
210214
override fun unsafeRunTimedTotal(limit: Duration): Option<A> = throw AssertionError("Unreachable")
211215
}

modules/effects/arrow-effects-data/src/main/kotlin/arrow/effects/IORunLoop.kt

Lines changed: 54 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
package arrow.effects
22

3-
import arrow.core.Continuation
43
import arrow.core.Either
54
import arrow.core.Left
65
import arrow.core.NonFatal
76
import arrow.core.Right
87
import arrow.effects.internal.Platform.ArrayStack
98
import kotlin.coroutines.CoroutineContext
9+
import kotlin.coroutines.EmptyCoroutineContext
1010
import kotlin.coroutines.startCoroutine
1111

1212
private typealias Current = IOOf<Any?>
@@ -71,7 +71,10 @@ internal object IORunLoop {
7171
}
7272
is IO.Async -> {
7373
// Return case for Async operations
74-
return suspendInAsync(currentIO as IO<A>, bFirst, bRest, currentIO.k)
74+
return suspendInAsync(currentIO as IO<A>, bFirst, bRest)
75+
}
76+
is IO.Effect -> {
77+
return suspendInAsync(currentIO as IO<A>, bFirst, bRest)
7578
}
7679
is IO.Bind<*, *> -> {
7780
if (bFirst != null) {
@@ -87,15 +90,13 @@ internal object IORunLoop {
8790
bRest.push(bFirst)
8891
}
8992
val localCurrent = currentIO
90-
9193
val currentCC = localCurrent.cc
92-
9394
val localCont = currentIO.cont
9495

9596
bFirst = { c: Any? -> IO.just(c) }
9697

97-
currentIO = IO.async { conn, cc ->
98-
loop(localCont, conn, cc.asyncCallback(currentCC), null, null, null)
98+
currentIO = IO.Bind(localCont) { a ->
99+
IO.Effect(currentCC) { a }
99100
}
100101
}
101102
is IO.Map<*, *> -> {
@@ -136,20 +137,15 @@ internal object IORunLoop {
136137
private fun <A> suspendInAsync(
137138
currentIO: IO<A>,
138139
bFirst: BindF?,
139-
bRest: CallStack?,
140-
register: IOProc<Any?>
140+
bRest: CallStack?
141141
): IO<A> =
142-
// Hitting an async boundary means we have to stop, however
143-
// if we had previous `flatMap` operations then we need to resume
144-
// the loop with the collected stack
145-
when {
146-
bFirst != null || (bRest != null && bRest.isNotEmpty()) ->
147-
IO.Async { conn, cb ->
148-
val rcb = RestartCallback(conn, cb as Callback)
149-
rcb.prepare(bFirst, bRest)
150-
register(conn, rcb)
151-
}
152-
else -> currentIO
142+
// Hitting an async boundary means we have to stop, however if we had previous `flatMap` operations then we need to resume the loop with the collected stack
143+
if (bFirst != null || (bRest != null && bRest.isNotEmpty())) {
144+
IO.Async { conn, cb ->
145+
loop(currentIO, conn, cb as Callback, null, bFirst, bRest)
146+
}
147+
} else {
148+
currentIO
153149
}
154150

155151
private fun loop(
@@ -216,9 +212,18 @@ internal object IORunLoop {
216212
if (rcb == null) {
217213
rcb = RestartCallback(conn, cb)
218214
}
219-
rcb.prepare(bFirst, bRest)
215+
216+
// Return case for Async operations
217+
rcb.start(currentIO, bFirst, bRest)
218+
return
219+
}
220+
is IO.Effect -> {
221+
if (rcb == null) {
222+
rcb = RestartCallback(conn, cb)
223+
}
224+
220225
// Return case for Async operations
221-
currentIO.k(conn, rcb)
226+
rcb.start(currentIO, bFirst, bRest)
222227
return
223228
}
224229
is IO.Bind<*, *> -> {
@@ -235,15 +240,13 @@ internal object IORunLoop {
235240
bRest.push(bFirst)
236241
}
237242
val localCurrent = currentIO
238-
239243
val currentCC = localCurrent.cc
240-
241244
val localCont = currentIO.cont
242245

243246
bFirst = { c: Any? -> IO.just(c) }
244247

245-
currentIO = IO.async { _, callback ->
246-
loop(localCont, conn, callback.asyncCallback(currentCC), null, null, null)
248+
currentIO = IO.Bind(localCont) { a ->
249+
IO.Effect(currentCC) { a }
247250
}
248251
}
249252
is IO.Map<*, *> -> {
@@ -353,7 +356,12 @@ internal object IORunLoop {
353356
* A `RestartCallback` gets created only once, per [startCancelable] (`unsafeRunAsync`) invocation, once an `Async`
354357
* state is hit, its job being to resume the loop after the boundary, but with the bind call-stack restored.
355358
*/
356-
private data class RestartCallback(val connInit: IOConnection, val cb: Callback) : Callback {
359+
private data class RestartCallback(val connInit: IOConnection, val cb: Callback) : Callback, kotlin.coroutines.Continuation<Any?> {
360+
361+
// Nasty trick to re-use `Continuation` with different CC.
362+
private var _context: CoroutineContext = EmptyCoroutineContext
363+
override val context: CoroutineContext
364+
get() = _context
357365

358366
private var conn: IOConnection = connInit
359367
private var canCall = false
@@ -364,12 +372,23 @@ internal object IORunLoop {
364372
this.conn = conn
365373
}
366374

367-
fun prepare(bFirst: BindF?, bRest: CallStack?) {
375+
private fun prepare(bFirst: BindF?, bRest: CallStack?) {
368376
canCall = true
369377
this.bFirst = bFirst
370378
this.bRest = bRest
371379
}
372380

381+
fun start(async: IO.Async<Any?>, bFirst: BindF?, bRest: CallStack?) {
382+
prepare(bFirst, bRest)
383+
async.k(conn, this)
384+
}
385+
386+
fun start(effect: IO.Effect<Any?>, bFirst: BindF?, bRest: CallStack?) {
387+
prepare(bFirst, bRest)
388+
this._context = effect.ctx
389+
effect.effect.startCoroutine(this)
390+
}
391+
373392
override operator fun invoke(either: Either<Throwable, Any?>) {
374393
if (canCall) {
375394
canCall = false
@@ -379,25 +398,17 @@ internal object IORunLoop {
379398
}
380399
}
381400
}
382-
}
383-
384-
private fun <T> ((Either<Throwable, T>) -> Unit).asyncCallback(currentCC: CoroutineContext): (Either<Throwable, T>) -> Unit =
385-
{ result ->
386-
val func: suspend () -> Unit = { this(result) }
387-
388-
val normalResume: Continuation<Unit> = object : Continuation<Unit> {
389-
override val context: CoroutineContext = currentCC
390401

391-
override fun resume(value: Unit) {
392-
}
393-
394-
override fun resumeWithException(exception: Throwable) {
395-
this@asyncCallback(Either.left(exception))
396-
}
402+
override fun resumeWith(result: Result<Any?>) {
403+
if (canCall) {
404+
canCall = false
405+
result.fold(
406+
{ a -> loop(IO.Pure(a), conn, cb, this, bFirst, bRest) },
407+
{ e -> loop(IO.RaiseError(e), conn, cb, this, bFirst, bRest) }
408+
)
397409
}
398-
399-
func.startCoroutine(normalResume)
400410
}
411+
}
401412

402413
private class RestoreContext(
403414
val old: IOConnection,

0 commit comments

Comments
 (0)