diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/remote/HordeGenerationRemoteDataSource.kt b/data/src/main/java/com/shifthackz/aisdv1/data/remote/HordeGenerationRemoteDataSource.kt index 48adc623..1d8a14c2 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/remote/HordeGenerationRemoteDataSource.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/remote/HordeGenerationRemoteDataSource.kt @@ -47,57 +47,52 @@ internal class HordeGenerationRemoteDataSource( ?.let(hordeApi::cancelRequest) ?: Completable.error(Throwable("No cached request id")) - private fun executeRequestChain(request: HordeGenerationAsyncRequest): Single { - val observableChain = hordeApi - .generateAsync(request) - .flatMapObservable { asyncStartResponse -> - statusSource.id = asyncStartResponse.id - asyncStartResponse.id?.let { id -> - val pingObs = Observable - .fromSingle(hordeApi.checkGeneration(id)) - .flatMap { pingResponse -> - if (pingResponse.isPossible == false) { - return@flatMap Observable.error(Throwable("Response is not possible")) - } - if (pingResponse.done == true) { - return@flatMap Observable.fromSingle(hordeApi.checkStatus(id)) - } - statusSource.update( - HordeProcessStatus( - waitTimeSeconds = pingResponse.waitTime ?: 0, - queuePosition = pingResponse.queuePosition, - ) - ) - return@flatMap Observable.error(RetryException()) + private fun executeRequestChain(request: HordeGenerationAsyncRequest) = hordeApi + .generateAsync(request) + .flatMapObservable { asyncStartResponse -> + statusSource.id = asyncStartResponse.id + asyncStartResponse.id?.let { id -> + Observable + .fromSingle(hordeApi.checkGeneration(id)) + .flatMap { pingResponse -> + if (pingResponse.isPossible == false) { + return@flatMap Observable.error(Throwable("Response is not possible")) + } + if (pingResponse.done == true) { + return@flatMap Observable.fromSingle(hordeApi.checkStatus(id)) } - .retryWhen { obs -> - obs.flatMap { t -> - if (t is RetryException) { - return@flatMap Observable - .timer(HORDE_SOCKET_PING_TIME_SECONDS, TimeUnit.SECONDS) - .doOnNext { - debugLog("Retrying HORDE status check...") - } + statusSource.update( + HordeProcessStatus( + waitTimeSeconds = pingResponse.waitTime ?: 0, + queuePosition = pingResponse.queuePosition, + ) + ) + return@flatMap Observable.error(RetryException()) + } + .retryWhen { obs -> + obs.flatMap { t -> + if (t is RetryException) Observable + .timer(HORDE_SOCKET_PING_TIME_SECONDS, TimeUnit.SECONDS) + .doOnNext { + debugLog("Retrying HORDE status check...") } - return@flatMap Observable.error(t) - } + else + Observable.error(t) } + } + } ?: Observable.error(Throwable("Horde returned null generation id")) + } + .flatMapSingle { + it.generations?.firstOrNull()?.let { generation -> + val bytes = URL(generation.img).readBytes() + val bitmap = BitmapFactory.decodeByteArray(bytes, 0, bytes.size) + Single.just(bitmap) + } ?: Single.error(Throwable("Error extracting image")) + } + .flatMapSingle { converter(BitmapToBase64Converter.Input(it)) } + .map { it.base64ImageString } + .let { Single.fromObservable(it) } - pingObs - } ?: Observable.error(Throwable("Horde returned null generation id")) - } - .flatMapSingle { - it.generations?.firstOrNull()?.let { generation -> - val bytes = URL(generation.img).readBytes() - val bitmap = BitmapFactory.decodeByteArray(bytes, 0, bytes.size) - Single.just(bitmap) - } ?: Single.error(Throwable("Error extracting image")) - } - .flatMapSingle { converter(BitmapToBase64Converter.Input(it)) } - .map { it.base64ImageString } - - return Single.fromObservable(observableChain) - } private class RetryException : Throwable() diff --git a/network/src/main/java/com/shifthackz/aisdv1/network/api/huggingface/HuggingFaceInferenceApi.kt b/network/src/main/java/com/shifthackz/aisdv1/network/api/huggingface/HuggingFaceInferenceApi.kt index 0e09a264..f18a3bea 100644 --- a/network/src/main/java/com/shifthackz/aisdv1/network/api/huggingface/HuggingFaceInferenceApi.kt +++ b/network/src/main/java/com/shifthackz/aisdv1/network/api/huggingface/HuggingFaceInferenceApi.kt @@ -4,6 +4,7 @@ import android.graphics.Bitmap import com.shifthackz.aisdv1.network.request.HuggingFaceGenerationRequest import io.reactivex.rxjava3.core.Single import okhttp3.ResponseBody +import retrofit2.Response import retrofit2.http.Body import retrofit2.http.POST import retrofit2.http.Path @@ -22,6 +23,6 @@ interface HuggingFaceInferenceApi { fun generate( @Path("model") model: String, @Body request: HuggingFaceGenerationRequest, - ): Single + ): Single> } } diff --git a/network/src/main/java/com/shifthackz/aisdv1/network/api/huggingface/HuggingFaceInferenceApiImpl.kt b/network/src/main/java/com/shifthackz/aisdv1/network/api/huggingface/HuggingFaceInferenceApiImpl.kt index 845e91e3..27b3d2ca 100644 --- a/network/src/main/java/com/shifthackz/aisdv1/network/api/huggingface/HuggingFaceInferenceApiImpl.kt +++ b/network/src/main/java/com/shifthackz/aisdv1/network/api/huggingface/HuggingFaceInferenceApiImpl.kt @@ -2,8 +2,11 @@ package com.shifthackz.aisdv1.network.api.huggingface import android.graphics.Bitmap import android.graphics.BitmapFactory +import com.shifthackz.aisdv1.core.common.log.debugLog import com.shifthackz.aisdv1.network.request.HuggingFaceGenerationRequest +import io.reactivex.rxjava3.core.Observable import io.reactivex.rxjava3.core.Single +import java.util.concurrent.TimeUnit internal class HuggingFaceInferenceApiImpl( private val rawApi: HuggingFaceInferenceApi.RawApi, @@ -14,8 +17,33 @@ internal class HuggingFaceInferenceApiImpl( request: HuggingFaceGenerationRequest, ): Single = rawApi .generate(model, request) - .map { body -> - val bytes = body.bytes() - BitmapFactory.decodeByteArray(bytes, 0, bytes.size) + .flatMapObservable { response -> + if (response.isSuccessful) { + response.body() + ?.bytes() + ?.let { BitmapFactory.decodeByteArray(it, 0, it.size) } + ?.let { Observable.just(it) } + ?: Observable.error(Throwable("Body is null")) + } else { + when (response.code()) { + 503 -> Observable.error(RetryException()) + + else -> { + Observable.error(Throwable(response.errorBody()?.string().toString())) + } + } + } + } + .retryWhen { obs -> + obs.flatMap { t -> + if (t is RetryException) Observable + .timer(20L, TimeUnit.SECONDS) + .doOnNext { debugLog("Retrying hugging face due to 503...") } + else + Observable.error(t) + } } + .let { Single.fromObservable(it) } + + private class RetryException : Throwable() } diff --git a/network/src/main/java/com/shifthackz/aisdv1/network/response/HuggingFaceErrorResponse.kt b/network/src/main/java/com/shifthackz/aisdv1/network/response/HuggingFaceErrorResponse.kt new file mode 100644 index 00000000..1974dcd4 --- /dev/null +++ b/network/src/main/java/com/shifthackz/aisdv1/network/response/HuggingFaceErrorResponse.kt @@ -0,0 +1,10 @@ +package com.shifthackz.aisdv1.network.response + +import com.google.gson.annotations.SerializedName + +data class HuggingFaceErrorResponse( + @SerializedName("error") + val error: String?, + @SerializedName("estimated_time") + val estimatedTime: Double?, +)