Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -47,57 +47,52 @@ internal class HordeGenerationRemoteDataSource(
?.let(hordeApi::cancelRequest)
?: Completable.error(Throwable("No cached request id"))

private fun executeRequestChain(request: HordeGenerationAsyncRequest): Single<String> {
val observableChain = hordeApi
.generateAsync(request)
.flatMapObservable { asyncStartResponse ->
statusSource.id = asyncStartResponse.id
asyncStartResponse.id?.let { id ->
val pingObs = Observable
.fromSingle(hordeApi.checkGeneration(id))
.flatMap { pingResponse ->
if (pingResponse.isPossible == false) {
return@flatMap Observable.error(Throwable("Response is not possible"))
}
if (pingResponse.done == true) {
return@flatMap Observable.fromSingle(hordeApi.checkStatus(id))
}
statusSource.update(
HordeProcessStatus(
waitTimeSeconds = pingResponse.waitTime ?: 0,
queuePosition = pingResponse.queuePosition,
)
)
return@flatMap Observable.error(RetryException())
private fun executeRequestChain(request: HordeGenerationAsyncRequest) = hordeApi
.generateAsync(request)
.flatMapObservable { asyncStartResponse ->
statusSource.id = asyncStartResponse.id
asyncStartResponse.id?.let { id ->
Observable
.fromSingle(hordeApi.checkGeneration(id))
.flatMap { pingResponse ->
if (pingResponse.isPossible == false) {
return@flatMap Observable.error(Throwable("Response is not possible"))
}
if (pingResponse.done == true) {
return@flatMap Observable.fromSingle(hordeApi.checkStatus(id))
}
.retryWhen { obs ->
obs.flatMap { t ->
if (t is RetryException) {
return@flatMap Observable
.timer(HORDE_SOCKET_PING_TIME_SECONDS, TimeUnit.SECONDS)
.doOnNext {
debugLog("Retrying HORDE status check...")
}
statusSource.update(
HordeProcessStatus(
waitTimeSeconds = pingResponse.waitTime ?: 0,
queuePosition = pingResponse.queuePosition,
)
)
return@flatMap Observable.error(RetryException())
}
.retryWhen { obs ->
obs.flatMap { t ->
if (t is RetryException) Observable
.timer(HORDE_SOCKET_PING_TIME_SECONDS, TimeUnit.SECONDS)
.doOnNext {
debugLog("Retrying HORDE status check...")
}
return@flatMap Observable.error(t)
}
else
Observable.error(t)
}
}
} ?: Observable.error(Throwable("Horde returned null generation id"))
}
.flatMapSingle {
it.generations?.firstOrNull()?.let { generation ->
val bytes = URL(generation.img).readBytes()
val bitmap = BitmapFactory.decodeByteArray(bytes, 0, bytes.size)
Single.just(bitmap)
} ?: Single.error(Throwable("Error extracting image"))
}
.flatMapSingle { converter(BitmapToBase64Converter.Input(it)) }
.map { it.base64ImageString }
.let { Single.fromObservable(it) }

pingObs
} ?: Observable.error(Throwable("Horde returned null generation id"))
}
.flatMapSingle {
it.generations?.firstOrNull()?.let { generation ->
val bytes = URL(generation.img).readBytes()
val bitmap = BitmapFactory.decodeByteArray(bytes, 0, bytes.size)
Single.just(bitmap)
} ?: Single.error(Throwable("Error extracting image"))
}
.flatMapSingle { converter(BitmapToBase64Converter.Input(it)) }
.map { it.base64ImageString }

return Single.fromObservable(observableChain)
}

private class RetryException : Throwable()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import android.graphics.Bitmap
import com.shifthackz.aisdv1.network.request.HuggingFaceGenerationRequest
import io.reactivex.rxjava3.core.Single
import okhttp3.ResponseBody
import retrofit2.Response
import retrofit2.http.Body
import retrofit2.http.POST
import retrofit2.http.Path
Expand All @@ -22,6 +23,6 @@ interface HuggingFaceInferenceApi {
fun generate(
@Path("model") model: String,
@Body request: HuggingFaceGenerationRequest,
): Single<ResponseBody>
): Single<Response<ResponseBody>>
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@ package com.shifthackz.aisdv1.network.api.huggingface

import android.graphics.Bitmap
import android.graphics.BitmapFactory
import com.shifthackz.aisdv1.core.common.log.debugLog
import com.shifthackz.aisdv1.network.request.HuggingFaceGenerationRequest
import io.reactivex.rxjava3.core.Observable
import io.reactivex.rxjava3.core.Single
import java.util.concurrent.TimeUnit

internal class HuggingFaceInferenceApiImpl(
private val rawApi: HuggingFaceInferenceApi.RawApi,
Expand All @@ -14,8 +17,33 @@ internal class HuggingFaceInferenceApiImpl(
request: HuggingFaceGenerationRequest,
): Single<Bitmap> = rawApi
.generate(model, request)
.map { body ->
val bytes = body.bytes()
BitmapFactory.decodeByteArray(bytes, 0, bytes.size)
.flatMapObservable { response ->
if (response.isSuccessful) {
response.body()
?.bytes()
?.let { BitmapFactory.decodeByteArray(it, 0, it.size) }
?.let { Observable.just(it) }
?: Observable.error(Throwable("Body is null"))
} else {
when (response.code()) {
503 -> Observable.error(RetryException())

else -> {
Observable.error(Throwable(response.errorBody()?.string().toString()))
}
}
}
}
.retryWhen { obs ->
obs.flatMap { t ->
if (t is RetryException) Observable
.timer(20L, TimeUnit.SECONDS)
.doOnNext { debugLog("Retrying hugging face due to 503...") }
else
Observable.error(t)
}
}
.let { Single.fromObservable(it) }

private class RetryException : Throwable()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package com.shifthackz.aisdv1.network.response

import com.google.gson.annotations.SerializedName

data class HuggingFaceErrorResponse(
@SerializedName("error")
val error: String?,
@SerializedName("estimated_time")
val estimatedTime: Double?,
)