From 97f579a676c025d822cef91d8bd8bfa15d1b61b7 Mon Sep 17 00:00:00 2001 From: ShiftHackZ Date: Sat, 3 Aug 2024 21:02:14 +0300 Subject: [PATCH 01/11] Swarm UI working prototype implementation --- .../aisdv1/data/di/RemoteDataSourceModule.kt | 13 ++- .../aisdv1/data/di/RepositoryModule.kt | 3 + .../SwarmUiTextToImagePayloadMappers.kt | 38 ++++++ .../data/preference/PreferenceManagerImpl.kt | 8 ++ .../data/preference/SessionPreferenceImpl.kt | 6 + .../SwarmUiGenerationRemoteDataSource.kt | 57 +++++++++ .../SwarmUiGenerationRepositoryImpl.kt | 51 ++++++++ .../datasource/SwarmUiGenerationDataSource.kt | 14 +++ .../aisdv1/domain/di/DomainModule.kt | 6 + .../aisdv1/domain/entity/Configuration.kt | 1 + .../aisdv1/domain/entity/FeatureTag.kt | 1 + .../aisdv1/domain/entity/ServerSource.kt | 14 +++ .../settings/SetupConnectionInterActor.kt | 2 + .../settings/SetupConnectionInterActorImpl.kt | 2 + .../domain/preference/PreferenceManager.kt | 1 + .../domain/preference/SessionPreference.kt | 1 + .../repository/SwarmUiGenerationRepository.kt | 12 ++ .../TestSwarmUiConnectivityUseCase.kt | 7 ++ .../TestSwarmUiConnectivityUseCaseImpl.kt | 10 ++ .../generation/TextToImageUseCaseImpl.kt | 3 + .../settings/ConnectToSwarmUiUseCase.kt | 8 ++ .../settings/ConnectToSwarmUiUseCaseImpl.kt | 43 +++++++ .../settings/GetConfigurationUseCaseImpl.kt | 1 + .../SetServerConfigurationUseCaseImpl.kt | 2 + .../aisdv1/network/api/swarmui/SwarmUiApi.kt | 47 ++++++++ .../network/api/swarmui/SwarmUiApiImpl.kt | 26 +++++ .../aisdv1/network/di/NetworkModule.kt | 9 ++ .../network/interceptor/LoggingInterceptor.kt | 3 +- .../request/SwarmUiGenerationRequest.kt | 18 +++ .../response/SwarmUiGenerationResponse.kt | 8 ++ .../response/SwarmUiSessionResponse.kt | 8 ++ .../screen/img2img/ImageToImageScreen.kt | 3 +- .../screen/settings/SettingsScreen.kt | 1 + .../screen/setup/ServerSetupIntent.kt | 2 + .../screen/setup/ServerSetupState.kt | 2 + .../screen/setup/ServerSetupViewModel.kt | 110 +++++++++++------- .../components/ConfigurationModeButton.kt | 5 +- .../screen/setup/forms/AuthCredentialsForm.kt | 99 ++++++++++++++++ .../screen/setup/forms/Automatic1111Form.kt | 77 +----------- .../screen/setup/forms/SwarmUiForm.kt | 60 ++++++++++ .../screen/setup/mappers/FeatureTagMapper.kt | 1 + .../screen/setup/steps/ConfigurationStep.kt | 6 + .../widget/engine/EngineSelectionComponent.kt | 1 + .../widget/input/GenerationInputForm.kt | 1 + presentation/src/main/res/values/strings.xml | 5 + 45 files changed, 674 insertions(+), 122 deletions(-) create mode 100644 data/src/main/java/com/shifthackz/aisdv1/data/mappers/SwarmUiTextToImagePayloadMappers.kt create mode 100644 data/src/main/java/com/shifthackz/aisdv1/data/remote/SwarmUiGenerationRemoteDataSource.kt create mode 100644 data/src/main/java/com/shifthackz/aisdv1/data/repository/SwarmUiGenerationRepositoryImpl.kt create mode 100644 domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/SwarmUiGenerationDataSource.kt create mode 100644 domain/src/main/java/com/shifthackz/aisdv1/domain/repository/SwarmUiGenerationRepository.kt create mode 100644 domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/connectivity/TestSwarmUiConnectivityUseCase.kt create mode 100644 domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/connectivity/TestSwarmUiConnectivityUseCaseImpl.kt create mode 100644 domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToSwarmUiUseCase.kt create mode 100644 domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToSwarmUiUseCaseImpl.kt create mode 100644 network/src/main/java/com/shifthackz/aisdv1/network/api/swarmui/SwarmUiApi.kt create mode 100644 network/src/main/java/com/shifthackz/aisdv1/network/api/swarmui/SwarmUiApiImpl.kt create mode 100644 network/src/main/java/com/shifthackz/aisdv1/network/request/SwarmUiGenerationRequest.kt create mode 100644 network/src/main/java/com/shifthackz/aisdv1/network/response/SwarmUiGenerationResponse.kt create mode 100644 network/src/main/java/com/shifthackz/aisdv1/network/response/SwarmUiSessionResponse.kt create mode 100644 presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/AuthCredentialsForm.kt create mode 100644 presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/SwarmUiForm.kt diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/di/RemoteDataSourceModule.kt b/data/src/main/java/com/shifthackz/aisdv1/data/di/RemoteDataSourceModule.kt index 05d54d9a..d3eb6205 100755 --- a/data/src/main/java/com/shifthackz/aisdv1/data/di/RemoteDataSourceModule.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/di/RemoteDataSourceModule.kt @@ -20,6 +20,7 @@ import com.shifthackz.aisdv1.data.remote.StableDiffusionHyperNetworksRemoteDataS import com.shifthackz.aisdv1.data.remote.StableDiffusionLorasRemoteDataSource import com.shifthackz.aisdv1.data.remote.StableDiffusionModelsRemoteDataSource import com.shifthackz.aisdv1.data.remote.StableDiffusionSamplersRemoteDataSource +import com.shifthackz.aisdv1.data.remote.SwarmUiGenerationRemoteDataSource import com.shifthackz.aisdv1.domain.datasource.DownloadableModelDataSource import com.shifthackz.aisdv1.domain.datasource.HordeGenerationDataSource import com.shifthackz.aisdv1.domain.datasource.HuggingFaceGenerationDataSource @@ -36,6 +37,7 @@ import com.shifthackz.aisdv1.domain.datasource.StableDiffusionHyperNetworksDataS import com.shifthackz.aisdv1.domain.datasource.StableDiffusionLorasDataSource import com.shifthackz.aisdv1.domain.datasource.StableDiffusionModelsDataSource import com.shifthackz.aisdv1.domain.datasource.StableDiffusionSamplersDataSource +import com.shifthackz.aisdv1.domain.datasource.SwarmUiGenerationDataSource import com.shifthackz.aisdv1.domain.entity.ServerSource import com.shifthackz.aisdv1.domain.gateway.ServerConnectivityGateway import com.shifthackz.aisdv1.domain.preference.PreferenceManager @@ -51,8 +53,12 @@ val remoteDataSourceModule = module { single { ServerUrlProvider { endpoint -> val prefs = get() - Single - .fromCallable(prefs::serverUrl) + val chain = if (prefs.source == ServerSource.SWARM_UI) { + Single.fromCallable(prefs::swarmServerUrl) + } else { + Single.fromCallable(prefs::serverUrl) + } + chain .map(String::fixUrlSlashes) .map { baseUrl -> "$baseUrl/$endpoint" } } @@ -61,6 +67,7 @@ val remoteDataSourceModule = module { factoryOf(::HordeGenerationRemoteDataSource) bind HordeGenerationDataSource.Remote::class factoryOf(::HuggingFaceGenerationRemoteDataSource) bind HuggingFaceGenerationDataSource.Remote::class factoryOf(::OpenAiGenerationRemoteDataSource) bind OpenAiGenerationDataSource.Remote::class + factoryOf(::SwarmUiGenerationRemoteDataSource) bind SwarmUiGenerationDataSource.Remote::class factoryOf(::StableDiffusionGenerationRemoteDataSource) bind StableDiffusionGenerationDataSource.Remote::class factoryOf(::StableDiffusionSamplersRemoteDataSource) bind StableDiffusionSamplersDataSource.Remote::class factoryOf(::StableDiffusionModelsRemoteDataSource) bind StableDiffusionModelsDataSource.Remote::class @@ -78,7 +85,7 @@ val remoteDataSourceModule = module { factory { val lambda: () -> Boolean = { val prefs = get() - prefs.source == ServerSource.AUTOMATIC1111 + prefs.source == ServerSource.AUTOMATIC1111 || prefs.source == ServerSource.SWARM_UI } val monitor = get { parametersOf(lambda) } ServerConnectivityGatewayImpl(monitor, get()) diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/di/RepositoryModule.kt b/data/src/main/java/com/shifthackz/aisdv1/data/di/RepositoryModule.kt index ea0d16ca..7e61b96e 100755 --- a/data/src/main/java/com/shifthackz/aisdv1/data/di/RepositoryModule.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/di/RepositoryModule.kt @@ -20,6 +20,7 @@ import com.shifthackz.aisdv1.data.repository.StableDiffusionHyperNetworksReposit import com.shifthackz.aisdv1.data.repository.StableDiffusionLorasRepositoryImpl import com.shifthackz.aisdv1.data.repository.StableDiffusionModelsRepositoryImpl import com.shifthackz.aisdv1.data.repository.StableDiffusionSamplersRepositoryImpl +import com.shifthackz.aisdv1.data.repository.SwarmUiGenerationRepositoryImpl import com.shifthackz.aisdv1.data.repository.TemporaryGenerationResultRepositoryImpl import com.shifthackz.aisdv1.data.repository.WakeLockRepositoryImpl import com.shifthackz.aisdv1.domain.repository.DownloadableModelRepository @@ -40,6 +41,7 @@ import com.shifthackz.aisdv1.domain.repository.StableDiffusionHyperNetworksRepos import com.shifthackz.aisdv1.domain.repository.StableDiffusionLorasRepository import com.shifthackz.aisdv1.domain.repository.StableDiffusionModelsRepository import com.shifthackz.aisdv1.domain.repository.StableDiffusionSamplersRepository +import com.shifthackz.aisdv1.domain.repository.SwarmUiGenerationRepository import com.shifthackz.aisdv1.domain.repository.TemporaryGenerationResultRepository import com.shifthackz.aisdv1.domain.repository.WakeLockRepository import org.koin.android.ext.koin.androidContext @@ -60,6 +62,7 @@ val repositoryModule = module { factoryOf(::HordeGenerationRepositoryImpl) bind HordeGenerationRepository::class factoryOf(::HuggingFaceGenerationRepositoryImpl) bind HuggingFaceGenerationRepository::class factoryOf(::OpenAiGenerationRepositoryImpl) bind OpenAiGenerationRepository::class + factoryOf(::SwarmUiGenerationRepositoryImpl) bind SwarmUiGenerationRepository::class factoryOf(::StabilityAiGenerationRepositoryImpl) bind StabilityAiGenerationRepository::class factoryOf(::StabilityAiCreditsRepositoryImpl) bind StabilityAiCreditsRepository::class factoryOf(::StabilityAiEnginesRepositoryImpl) bind StabilityAiEnginesRepository::class diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/SwarmUiTextToImagePayloadMappers.kt b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/SwarmUiTextToImagePayloadMappers.kt new file mode 100644 index 00000000..2998c824 --- /dev/null +++ b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/SwarmUiTextToImagePayloadMappers.kt @@ -0,0 +1,38 @@ +package com.shifthackz.aisdv1.data.mappers + +import com.shifthackz.aisdv1.domain.entity.TextToImagePayload +import com.shifthackz.aisdv1.network.request.SwarmUiGenerationRequest + +fun TextToImagePayload.mapToSwarmUiRequest(sessionId: String): SwarmUiGenerationRequest = with(this) { + SwarmUiGenerationRequest( + sessionId = sessionId, + model = "OfficialStableDiffusion/sd_xl_base_1.0", + images = 1, + prompt = prompt, + width = width, + height = height, + ) +} +// +//fun Pair.mapToAiGenResult(): AiGenerationResult = +// let { (payload, response) -> +// AiGenerationResult( +// id = 0L, +// image = response.images?.firstOrNull() ?: "", +// inputImage = "", +// createdAt = Date(), +// type = AiGenerationResult.Type.TEXT_TO_IMAGE, +// denoisingStrength = 0f, +// prompt = payload.prompt, +// negativePrompt = payload.negativePrompt, +// width = payload.width, +// height = payload.height, +// samplingSteps = payload.samplingSteps, +// cfgScale = payload.cfgScale, +// restoreFaces = payload.restoreFaces, +// sampler = payload.sampler, +// seed = payload.seed, +// subSeed = payload.subSeed, +// subSeedStrength = payload.subSeedStrength, +// ) +// } diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImpl.kt index 20d91c74..e250e4d6 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImpl.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImpl.kt @@ -27,6 +27,13 @@ class PreferenceManagerImpl( .apply() .also { onPreferencesChanged() } + override var swarmServerUrl: String + get() = (preferences.getString(KEY_SWARM_SERVER_URL, "") ?: "").fixUrlSlashes() + set(value) = preferences.edit() + .putString(KEY_SWARM_SERVER_URL, value.fixUrlSlashes()) + .apply() + .also { onPreferencesChanged() } + override var demoMode: Boolean get() = preferences.getBoolean(KEY_DEMO_MODE, false) set(value) = preferences.edit() @@ -215,6 +222,7 @@ class PreferenceManagerImpl( companion object { const val KEY_SERVER_URL = "key_server_url" + const val KEY_SWARM_SERVER_URL = "key_swarm_server_url" const val KEY_DEMO_MODE = "key_demo_mode" const val KEY_MONITOR_CONNECTIVITY = "key_monitor_connectivity" const val KEY_AI_AUTO_SAVE = "key_ai_auto_save" diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/preference/SessionPreferenceImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/preference/SessionPreferenceImpl.kt index 2fb82077..b076d72a 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/preference/SessionPreferenceImpl.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/preference/SessionPreferenceImpl.kt @@ -5,10 +5,16 @@ import com.shifthackz.aisdv1.domain.preference.SessionPreference class SessionPreferenceImpl : SessionPreference { private var _coinsPerDay: Int = -1 + private var _swarmUiSessionId: String = "" override var coinsPerDay: Int get() = _coinsPerDay set(value) { _coinsPerDay = value } + override var swarmUiSessionId: String + get() = _swarmUiSessionId + set(value) { + _swarmUiSessionId = value + } } diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/remote/SwarmUiGenerationRemoteDataSource.kt b/data/src/main/java/com/shifthackz/aisdv1/data/remote/SwarmUiGenerationRemoteDataSource.kt new file mode 100644 index 00000000..1a394a14 --- /dev/null +++ b/data/src/main/java/com/shifthackz/aisdv1/data/remote/SwarmUiGenerationRemoteDataSource.kt @@ -0,0 +1,57 @@ +package com.shifthackz.aisdv1.data.remote + +import com.shifthackz.aisdv1.core.common.extensions.fixUrlSlashes +import com.shifthackz.aisdv1.core.imageprocessing.BitmapToBase64Converter +import com.shifthackz.aisdv1.data.mappers.mapCloudToAiGenResult +import com.shifthackz.aisdv1.data.mappers.mapToSwarmUiRequest +import com.shifthackz.aisdv1.data.provider.ServerUrlProvider +import com.shifthackz.aisdv1.domain.datasource.SwarmUiGenerationDataSource +import com.shifthackz.aisdv1.domain.entity.TextToImagePayload +import com.shifthackz.aisdv1.network.api.swarmui.SwarmUiApi +import com.shifthackz.aisdv1.network.api.swarmui.SwarmUiApi.Companion.PATH_SESSION +import com.shifthackz.aisdv1.network.api.swarmui.SwarmUiApi.Companion.PATH_TXT_TO_IMG +import io.reactivex.rxjava3.core.Single + +class SwarmUiGenerationRemoteDataSource( + private val serverUrlProvider: ServerUrlProvider, + private val api: SwarmUiApi, + private val converter: BitmapToBase64Converter, +) : SwarmUiGenerationDataSource.Remote { + + override fun getNewSession() = PATH_SESSION + .let(serverUrlProvider::invoke) + .flatMap(::getSessionForUrl) + + override fun getNewSession(url: String) = + getSessionForUrl("${url.fixUrlSlashes()}/$PATH_SESSION") + + override fun textToImage( + sessionId: String, + payload: TextToImagePayload, + ) = serverUrlProvider(PATH_TXT_TO_IMG) + .flatMap { url -> api.textToImage(url, payload.mapToSwarmUiRequest(sessionId)) } + .flatMap { response -> + serverUrlProvider("").map { url -> response to url } + } + .flatMap { (response, url) -> + response.images + ?.firstOrNull() + ?.let { endpoint -> Single.just("$url/$endpoint".fixUrlSlashes()) } + ?: Single.error(IllegalStateException("Bad response")) + } + .flatMap(api::downloadImage) + .map(BitmapToBase64Converter::Input) + .flatMap(converter::invoke) + .map(BitmapToBase64Converter.Output::base64ImageString) + .map { base64 -> payload to base64 } + .map(Pair::mapCloudToAiGenResult) + + private fun getSessionForUrl(url: String) = api + .getNewSession(url) + .flatMap { response -> + response.sessionId + ?.takeIf(String::isNotBlank) + ?.let { Single.just(it) } + ?: Single.error(IllegalStateException("Bad session ID.")) + } +} diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/repository/SwarmUiGenerationRepositoryImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/repository/SwarmUiGenerationRepositoryImpl.kt new file mode 100644 index 00000000..8d4a2858 --- /dev/null +++ b/data/src/main/java/com/shifthackz/aisdv1/data/repository/SwarmUiGenerationRepositoryImpl.kt @@ -0,0 +1,51 @@ +package com.shifthackz.aisdv1.data.repository + +import com.shifthackz.aisdv1.core.imageprocessing.Base64ToBitmapConverter +import com.shifthackz.aisdv1.data.core.CoreGenerationRepository +import com.shifthackz.aisdv1.domain.datasource.GenerationResultDataSource +import com.shifthackz.aisdv1.domain.datasource.SwarmUiGenerationDataSource +import com.shifthackz.aisdv1.domain.entity.TextToImagePayload +import com.shifthackz.aisdv1.domain.gateway.MediaStoreGateway +import com.shifthackz.aisdv1.domain.preference.PreferenceManager +import com.shifthackz.aisdv1.domain.preference.SessionPreference +import com.shifthackz.aisdv1.domain.repository.SwarmUiGenerationRepository +import io.reactivex.rxjava3.core.Single + +internal class SwarmUiGenerationRepositoryImpl( + mediaStoreGateway: MediaStoreGateway, + base64ToBitmapConverter: Base64ToBitmapConverter, + localDataSource: GenerationResultDataSource.Local, + private val remoteDataSource: SwarmUiGenerationDataSource.Remote, + private val preferenceManager: PreferenceManager, + private val sessionPreference: SessionPreference, +) : CoreGenerationRepository( + mediaStoreGateway, + base64ToBitmapConverter, + localDataSource, + preferenceManager, +), SwarmUiGenerationRepository { + + override fun checkApiAvailability() = obtainSessionId() + .ignoreElement() + + override fun checkApiAvailability(url: String) = obtainSessionId(url) + .ignoreElement() + + override fun generateFromText(payload: TextToImagePayload) = obtainSessionId() + .flatMap { sessionId -> remoteDataSource.textToImage(sessionId, payload) } + .flatMap(::insertGenerationResult) + + private fun obtainSessionId(connectUrl: String? = null) = + if (sessionPreference.swarmUiSessionId.isBlank()) { + val chain = connectUrl + ?.let(remoteDataSource::getNewSession) + ?: remoteDataSource.getNewSession() + + chain.map { sessionId -> + sessionPreference.swarmUiSessionId = sessionId + sessionId + } + } else { + Single.just(sessionPreference.swarmUiSessionId) + } +} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/SwarmUiGenerationDataSource.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/SwarmUiGenerationDataSource.kt new file mode 100644 index 00000000..bafca9ab --- /dev/null +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/SwarmUiGenerationDataSource.kt @@ -0,0 +1,14 @@ +package com.shifthackz.aisdv1.domain.datasource + +import com.shifthackz.aisdv1.domain.entity.AiGenerationResult +import com.shifthackz.aisdv1.domain.entity.TextToImagePayload +import io.reactivex.rxjava3.core.Single + +sealed interface SwarmUiGenerationDataSource { + + interface Remote { + fun getNewSession(): Single + fun getNewSession(url: String): Single + fun textToImage(sessionId: String, payload: TextToImagePayload): Single + } +} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/di/DomainModule.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/di/DomainModule.kt index 23338d53..e2823a4d 100755 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/di/DomainModule.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/di/DomainModule.kt @@ -26,6 +26,8 @@ import com.shifthackz.aisdv1.domain.usecase.connectivity.TestOpenAiApiKeyUseCase import com.shifthackz.aisdv1.domain.usecase.connectivity.TestOpenAiApiKeyUseCaseImpl import com.shifthackz.aisdv1.domain.usecase.connectivity.TestStabilityAiApiKeyUseCase import com.shifthackz.aisdv1.domain.usecase.connectivity.TestStabilityAiApiKeyUseCaseImpl +import com.shifthackz.aisdv1.domain.usecase.connectivity.TestSwarmUiConnectivityUseCase +import com.shifthackz.aisdv1.domain.usecase.connectivity.TestSwarmUiConnectivityUseCaseImpl import com.shifthackz.aisdv1.domain.usecase.debug.DebugInsertBadBase64UseCase import com.shifthackz.aisdv1.domain.usecase.debug.DebugInsertBadBase64UseCaseImpl import com.shifthackz.aisdv1.domain.usecase.downloadable.DeleteModelUseCase @@ -86,6 +88,8 @@ import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToOpenAiUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToOpenAiUseCaseImpl import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToStabilityAiUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToStabilityAiUseCaseImpl +import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToSwarmUiUseCase +import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToSwarmUiUseCaseImpl import com.shifthackz.aisdv1.domain.usecase.settings.GetConfigurationUseCase import com.shifthackz.aisdv1.domain.usecase.settings.GetConfigurationUseCaseImpl import com.shifthackz.aisdv1.domain.usecase.settings.SetServerConfigurationUseCase @@ -128,6 +132,7 @@ internal val useCasesModule = module { factoryOf(::TestHuggingFaceApiKeyUseCaseImpl) bind TestHuggingFaceApiKeyUseCase::class factoryOf(::TestOpenAiApiKeyUseCaseImpl) bind TestOpenAiApiKeyUseCase::class factoryOf(::TestStabilityAiApiKeyUseCaseImpl) bind TestStabilityAiApiKeyUseCase::class + factoryOf(::TestSwarmUiConnectivityUseCaseImpl) bind TestSwarmUiConnectivityUseCase::class factoryOf(::SaveGenerationResultUseCaseImpl) bind SaveGenerationResultUseCase::class factoryOf(::ObserveSeverConnectivityUseCaseImpl) bind ObserveSeverConnectivityUseCase::class factoryOf(::ObserveHordeProcessStatusUseCaseImpl) bind ObserveHordeProcessStatusUseCase::class @@ -146,6 +151,7 @@ internal val useCasesModule = module { factoryOf(::ConnectToHordeUseCaseImpl) bind ConnectToHordeUseCase::class factoryOf(::ConnectToLocalDiffusionUseCaseImpl) bind ConnectToLocalDiffusionUseCase::class factoryOf(::ConnectToA1111UseCaseImpl) bind ConnectToA1111UseCase::class + factoryOf(::ConnectToSwarmUiUseCaseImpl) bind ConnectToSwarmUiUseCase::class factoryOf(::ConnectToHuggingFaceUseCaseImpl) bind ConnectToHuggingFaceUseCase::class factoryOf(::ConnectToOpenAiUseCaseImpl) bind ConnectToOpenAiUseCase::class factoryOf(::ConnectToStabilityAiUseCaseImpl) bind ConnectToStabilityAiUseCase::class diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/Configuration.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/Configuration.kt index 1f2007d9..ab8a4200 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/Configuration.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/Configuration.kt @@ -4,6 +4,7 @@ import com.shifthackz.aisdv1.domain.feature.auth.AuthorizationCredentials data class Configuration( val serverUrl: String = "", + val swarmUiUrl: String = "", val demoMode: Boolean = false, val source: ServerSource = ServerSource.AUTOMATIC1111, val hordeApiKey: String = "", diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/FeatureTag.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/FeatureTag.kt index 370da5a3..d4c7db9b 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/FeatureTag.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/FeatureTag.kt @@ -3,6 +3,7 @@ package com.shifthackz.aisdv1.domain.entity enum class FeatureTag { Txt2Img, Img2Img, + OwnServer, Lora, TextualInversion, HyperNetworks, diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/ServerSource.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/ServerSource.kt index 067e9e51..dae3ebc7 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/ServerSource.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/ServerSource.kt @@ -9,6 +9,7 @@ enum class ServerSource( featureTags = setOf( FeatureTag.Txt2Img, FeatureTag.Img2Img, + FeatureTag.OwnServer, FeatureTag.MultipleModels, FeatureTag.Lora, FeatureTag.TextualInversion, @@ -16,6 +17,19 @@ enum class ServerSource( FeatureTag.Batch, ), ), + SWARM_UI( + key = "swarm_ui", + featureTags = setOf( + FeatureTag.Txt2Img, + FeatureTag.OwnServer, +// FeatureTag.Img2Img, + FeatureTag.MultipleModels, +// FeatureTag.Lora, +// FeatureTag.TextualInversion, +// FeatureTag.HyperNetworks, + FeatureTag.Batch, + ), + ), HORDE( key = "horde", featureTags = setOf( diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/interactor/settings/SetupConnectionInterActor.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/interactor/settings/SetupConnectionInterActor.kt index fc6dde12..4c42f81c 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/interactor/settings/SetupConnectionInterActor.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/interactor/settings/SetupConnectionInterActor.kt @@ -6,6 +6,7 @@ import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToHuggingFaceUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToLocalDiffusionUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToOpenAiUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToStabilityAiUseCase +import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToSwarmUiUseCase interface SetupConnectionInterActor { val connectToHorde: ConnectToHordeUseCase @@ -14,4 +15,5 @@ interface SetupConnectionInterActor { val connectToHuggingFace: ConnectToHuggingFaceUseCase val connectToOpenAi: ConnectToOpenAiUseCase val connectToStabilityAi: ConnectToStabilityAiUseCase + val connectToSwarmUi: ConnectToSwarmUiUseCase } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/interactor/settings/SetupConnectionInterActorImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/interactor/settings/SetupConnectionInterActorImpl.kt index 590ef434..05517631 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/interactor/settings/SetupConnectionInterActorImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/interactor/settings/SetupConnectionInterActorImpl.kt @@ -6,6 +6,7 @@ import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToHuggingFaceUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToLocalDiffusionUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToOpenAiUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToStabilityAiUseCase +import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToSwarmUiUseCase internal data class SetupConnectionInterActorImpl( override val connectToHorde: ConnectToHordeUseCase, @@ -14,4 +15,5 @@ internal data class SetupConnectionInterActorImpl( override val connectToHuggingFace: ConnectToHuggingFaceUseCase, override val connectToOpenAi: ConnectToOpenAiUseCase, override val connectToStabilityAi: ConnectToStabilityAiUseCase, + override val connectToSwarmUi: ConnectToSwarmUiUseCase, ) : SetupConnectionInterActor diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/preference/PreferenceManager.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/preference/PreferenceManager.kt index c37fd154..06311558 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/preference/PreferenceManager.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/preference/PreferenceManager.kt @@ -6,6 +6,7 @@ import io.reactivex.rxjava3.core.Flowable interface PreferenceManager { var serverUrl: String + var swarmServerUrl: String var demoMode: Boolean var monitorConnectivity: Boolean var autoSaveAiResults: Boolean diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/preference/SessionPreference.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/preference/SessionPreference.kt index 03bf43d1..ab918e77 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/preference/SessionPreference.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/preference/SessionPreference.kt @@ -2,4 +2,5 @@ package com.shifthackz.aisdv1.domain.preference interface SessionPreference { var coinsPerDay: Int + var swarmUiSessionId: String } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/SwarmUiGenerationRepository.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/SwarmUiGenerationRepository.kt new file mode 100644 index 00000000..bf4bb83e --- /dev/null +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/SwarmUiGenerationRepository.kt @@ -0,0 +1,12 @@ +package com.shifthackz.aisdv1.domain.repository + +import com.shifthackz.aisdv1.domain.entity.AiGenerationResult +import com.shifthackz.aisdv1.domain.entity.TextToImagePayload +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single + +interface SwarmUiGenerationRepository { + fun checkApiAvailability(): Completable + fun checkApiAvailability(url: String): Completable + fun generateFromText(payload: TextToImagePayload): Single +} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/connectivity/TestSwarmUiConnectivityUseCase.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/connectivity/TestSwarmUiConnectivityUseCase.kt new file mode 100644 index 00000000..6a29bc0d --- /dev/null +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/connectivity/TestSwarmUiConnectivityUseCase.kt @@ -0,0 +1,7 @@ +package com.shifthackz.aisdv1.domain.usecase.connectivity + +import io.reactivex.rxjava3.core.Completable + +interface TestSwarmUiConnectivityUseCase { + operator fun invoke(url: String): Completable +} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/connectivity/TestSwarmUiConnectivityUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/connectivity/TestSwarmUiConnectivityUseCaseImpl.kt new file mode 100644 index 00000000..e79cc0e5 --- /dev/null +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/connectivity/TestSwarmUiConnectivityUseCaseImpl.kt @@ -0,0 +1,10 @@ +package com.shifthackz.aisdv1.domain.usecase.connectivity + +import com.shifthackz.aisdv1.domain.repository.SwarmUiGenerationRepository + +class TestSwarmUiConnectivityUseCaseImpl( + private val repository: SwarmUiGenerationRepository, +) : TestSwarmUiConnectivityUseCase { + + override fun invoke(url: String) = repository.checkApiAvailability(url) +} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImpl.kt index 0d6ae856..4874171c 100755 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImpl.kt @@ -9,6 +9,7 @@ import com.shifthackz.aisdv1.domain.repository.LocalDiffusionGenerationRepositor import com.shifthackz.aisdv1.domain.repository.OpenAiGenerationRepository import com.shifthackz.aisdv1.domain.repository.StabilityAiGenerationRepository import com.shifthackz.aisdv1.domain.repository.StableDiffusionGenerationRepository +import com.shifthackz.aisdv1.domain.repository.SwarmUiGenerationRepository import io.reactivex.rxjava3.core.Observable internal class TextToImageUseCaseImpl( @@ -17,6 +18,7 @@ internal class TextToImageUseCaseImpl( private val huggingFaceGenerationRepository: HuggingFaceGenerationRepository, private val openAiGenerationRepository: OpenAiGenerationRepository, private val stabilityAiGenerationRepository: StabilityAiGenerationRepository, + private val swarmUiGenerationRepository: SwarmUiGenerationRepository, private val localDiffusionGenerationRepository: LocalDiffusionGenerationRepository, private val preferenceManager: PreferenceManager, ) : TextToImageUseCase { @@ -33,5 +35,6 @@ internal class TextToImageUseCaseImpl( ServerSource.AUTOMATIC1111 -> stableDiffusionGenerationRepository.generateFromText(payload) ServerSource.OPEN_AI -> openAiGenerationRepository.generateFromText(payload) ServerSource.STABILITY_AI -> stabilityAiGenerationRepository.generateFromText(payload) + ServerSource.SWARM_UI -> swarmUiGenerationRepository.generateFromText(payload) } } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToSwarmUiUseCase.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToSwarmUiUseCase.kt new file mode 100644 index 00000000..7d77f2a6 --- /dev/null +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToSwarmUiUseCase.kt @@ -0,0 +1,8 @@ +package com.shifthackz.aisdv1.domain.usecase.settings + +import com.shifthackz.aisdv1.domain.feature.auth.AuthorizationCredentials +import io.reactivex.rxjava3.core.Single + +interface ConnectToSwarmUiUseCase { + operator fun invoke(url: String, credentials: AuthorizationCredentials): Single> +} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToSwarmUiUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToSwarmUiUseCaseImpl.kt new file mode 100644 index 00000000..78d771c9 --- /dev/null +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToSwarmUiUseCaseImpl.kt @@ -0,0 +1,43 @@ +package com.shifthackz.aisdv1.domain.usecase.settings + +import com.shifthackz.aisdv1.domain.entity.Configuration +import com.shifthackz.aisdv1.domain.entity.ServerSource +import com.shifthackz.aisdv1.domain.feature.auth.AuthorizationCredentials +import com.shifthackz.aisdv1.domain.usecase.connectivity.TestSwarmUiConnectivityUseCase +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single +import java.util.concurrent.TimeUnit + +internal class ConnectToSwarmUiUseCaseImpl( + private val getConfigurationUseCase: GetConfigurationUseCase, + private val setServerConfigurationUseCase : SetServerConfigurationUseCase, + private val testSwarmUiConnectivityUseCase: TestSwarmUiConnectivityUseCase, +) : ConnectToSwarmUiUseCase { + + override fun invoke( + url: String, + credentials: AuthorizationCredentials, + ): Single> { + var configuration: Configuration? = null + return getConfigurationUseCase() + .map { originalConfiguration -> + configuration = originalConfiguration + originalConfiguration.copy( + source = ServerSource.SWARM_UI, + swarmUiUrl = url, + authCredentials = credentials, + ) + } + .flatMapCompletable(setServerConfigurationUseCase::invoke) + .delay(3L, TimeUnit.SECONDS) + .andThen(testSwarmUiConnectivityUseCase(url)) + .andThen(Single.just(Result.success(Unit))) + .timeout(30L, TimeUnit.SECONDS) + .onErrorResumeNext { t -> + val chain = configuration?.let(setServerConfigurationUseCase::invoke) + ?: Completable.complete() + + chain.andThen(Single.just(Result.failure(t))) + } + } +} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImpl.kt index 67f79ad3..bcad1249 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImpl.kt @@ -13,6 +13,7 @@ internal class GetConfigurationUseCaseImpl( override fun invoke(): Single = Single.just( Configuration( serverUrl = preferenceManager.serverUrl, + swarmUiUrl = preferenceManager.swarmServerUrl, demoMode = preferenceManager.demoMode, source = preferenceManager.source, hordeApiKey = preferenceManager.hordeApiKey, diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImpl.kt index 01ad9c66..19c2d794 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImpl.kt @@ -15,6 +15,7 @@ internal class SetServerConfigurationUseCaseImpl( authorizationStore.storeAuthorizationCredentials(configuration.authCredentials) preferenceManager.source = configuration.source preferenceManager.serverUrl = configuration.serverUrl + preferenceManager.swarmServerUrl = configuration.swarmUiUrl preferenceManager.demoMode = configuration.demoMode preferenceManager.hordeApiKey = configuration.hordeApiKey preferenceManager.openAiApiKey = configuration.openAiApiKey @@ -23,5 +24,6 @@ internal class SetServerConfigurationUseCaseImpl( preferenceManager.stabilityAiApiKey = configuration.stabilityAiApiKey preferenceManager.stabilityAiEngineId = configuration.stabilityAiEngineId preferenceManager.localModelId = configuration.localModelId + println("SWARM - set ${preferenceManager.swarmServerUrl}") } } diff --git a/network/src/main/java/com/shifthackz/aisdv1/network/api/swarmui/SwarmUiApi.kt b/network/src/main/java/com/shifthackz/aisdv1/network/api/swarmui/SwarmUiApi.kt new file mode 100644 index 00000000..15e55639 --- /dev/null +++ b/network/src/main/java/com/shifthackz/aisdv1/network/api/swarmui/SwarmUiApi.kt @@ -0,0 +1,47 @@ +package com.shifthackz.aisdv1.network.api.swarmui + +import android.graphics.Bitmap +import com.shifthackz.aisdv1.network.request.SwarmUiGenerationRequest +import com.shifthackz.aisdv1.network.response.SwarmUiGenerationResponse +import com.shifthackz.aisdv1.network.response.SwarmUiSessionResponse +import io.reactivex.rxjava3.core.Single +import okhttp3.ResponseBody +import retrofit2.Response +import retrofit2.http.Body +import retrofit2.http.GET +import retrofit2.http.POST +import retrofit2.http.Streaming +import retrofit2.http.Url + +interface SwarmUiApi { + + fun getNewSession(url: String): Single + + fun textToImage( + @Url url: String, + @Body request: SwarmUiGenerationRequest, + ): Single + + fun downloadImage(url: String): Single + + interface RawApi { + + @POST + fun getNewSession(@Url url: String, @Body map: Map): Single + + @POST + fun textToImage( + @Url url: String, + @Body request: SwarmUiGenerationRequest, + ): Single + + @Streaming + @GET + fun download(@Url url: String): Single> + } + + companion object { + const val PATH_SESSION = "API/GetNewSession" + const val PATH_TXT_TO_IMG = "API/GenerateText2Image" + } +} diff --git a/network/src/main/java/com/shifthackz/aisdv1/network/api/swarmui/SwarmUiApiImpl.kt b/network/src/main/java/com/shifthackz/aisdv1/network/api/swarmui/SwarmUiApiImpl.kt new file mode 100644 index 00000000..56a6dc4b --- /dev/null +++ b/network/src/main/java/com/shifthackz/aisdv1/network/api/swarmui/SwarmUiApiImpl.kt @@ -0,0 +1,26 @@ +package com.shifthackz.aisdv1.network.api.swarmui + +import android.graphics.BitmapFactory +import com.shifthackz.aisdv1.network.request.SwarmUiGenerationRequest +import io.reactivex.rxjava3.core.Single + +internal class SwarmUiApiImpl( + private val rawApi: SwarmUiApi.RawApi, +) : SwarmUiApi { + + override fun getNewSession(url: String) = rawApi.getNewSession(url, emptyMap()) + + override fun textToImage( + url: String, + request: SwarmUiGenerationRequest, + ) = rawApi.textToImage(url, request) + + override fun downloadImage(url: String) = rawApi.download(url) + .flatMap { response -> + response.body() + ?.bytes() + ?.let { BitmapFactory.decodeByteArray(it, 0, it.size) } + ?.let { Single.just(it) } + ?: Single.error(Throwable("Body is null")) + } +} diff --git a/network/src/main/java/com/shifthackz/aisdv1/network/di/NetworkModule.kt b/network/src/main/java/com/shifthackz/aisdv1/network/di/NetworkModule.kt index 0995039c..bf8c4867 100755 --- a/network/src/main/java/com/shifthackz/aisdv1/network/di/NetworkModule.kt +++ b/network/src/main/java/com/shifthackz/aisdv1/network/di/NetworkModule.kt @@ -14,6 +14,8 @@ import com.shifthackz.aisdv1.network.api.sdai.DownloadableModelsApi import com.shifthackz.aisdv1.network.api.sdai.DownloadableModelsApiImpl import com.shifthackz.aisdv1.network.api.sdai.HuggingFaceModelsApi import com.shifthackz.aisdv1.network.api.stabilityai.StabilityAiApi +import com.shifthackz.aisdv1.network.api.swarmui.SwarmUiApi +import com.shifthackz.aisdv1.network.api.swarmui.SwarmUiApiImpl import com.shifthackz.aisdv1.network.authenticator.RestAuthenticator import com.shifthackz.aisdv1.network.connectivity.ConnectivityMonitor import com.shifthackz.aisdv1.network.error.StabilityAiErrorMapper @@ -105,6 +107,12 @@ val networkModule = module { .create(Automatic1111RestApi::class.java) } + single { + get() + .withBaseUrl(get().stableDiffusionAutomaticApiUrl) + .create(SwarmUiApi.RawApi::class.java) + } + single { get() .withBaseUrl(get().hordeApiUrl) @@ -156,6 +164,7 @@ val networkModule = module { singleOf(::ImageCdnRestApiImpl) bind ImageCdnRestApi::class singleOf(::DownloadableModelsApiImpl) bind DownloadableModelsApi::class singleOf(::HuggingFaceInferenceApiImpl) bind HuggingFaceInferenceApi::class + singleOf(::SwarmUiApiImpl) bind SwarmUiApi::class factory { params -> ConnectivityMonitor(params.get()) diff --git a/network/src/main/java/com/shifthackz/aisdv1/network/interceptor/LoggingInterceptor.kt b/network/src/main/java/com/shifthackz/aisdv1/network/interceptor/LoggingInterceptor.kt index e138632e..b592f9b5 100644 --- a/network/src/main/java/com/shifthackz/aisdv1/network/interceptor/LoggingInterceptor.kt +++ b/network/src/main/java/com/shifthackz/aisdv1/network/interceptor/LoggingInterceptor.kt @@ -8,7 +8,8 @@ internal class LoggingInterceptor { fun get() = HttpLoggingInterceptor { message -> debugLog(HTTP_TAG, message) }.apply { - level = HttpLoggingInterceptor.Level.HEADERS +// level = HttpLoggingInterceptor.Level.HEADERS + level = HttpLoggingInterceptor.Level.BODY } companion object { diff --git a/network/src/main/java/com/shifthackz/aisdv1/network/request/SwarmUiGenerationRequest.kt b/network/src/main/java/com/shifthackz/aisdv1/network/request/SwarmUiGenerationRequest.kt new file mode 100644 index 00000000..b55df94a --- /dev/null +++ b/network/src/main/java/com/shifthackz/aisdv1/network/request/SwarmUiGenerationRequest.kt @@ -0,0 +1,18 @@ +package com.shifthackz.aisdv1.network.request + +import com.google.gson.annotations.SerializedName + +data class SwarmUiGenerationRequest( + @SerializedName("session_id") + val sessionId: String, + @SerializedName("model") + val model: String, + @SerializedName("images") + val images: Int, + @SerializedName("prompt") + val prompt: String, + @SerializedName("width") + val width: Int, + @SerializedName("height") + val height: Int, +) diff --git a/network/src/main/java/com/shifthackz/aisdv1/network/response/SwarmUiGenerationResponse.kt b/network/src/main/java/com/shifthackz/aisdv1/network/response/SwarmUiGenerationResponse.kt new file mode 100644 index 00000000..ae8cdd12 --- /dev/null +++ b/network/src/main/java/com/shifthackz/aisdv1/network/response/SwarmUiGenerationResponse.kt @@ -0,0 +1,8 @@ +package com.shifthackz.aisdv1.network.response + +import com.google.gson.annotations.SerializedName + +data class SwarmUiGenerationResponse( + @SerializedName("images") + val images: List?, +) diff --git a/network/src/main/java/com/shifthackz/aisdv1/network/response/SwarmUiSessionResponse.kt b/network/src/main/java/com/shifthackz/aisdv1/network/response/SwarmUiSessionResponse.kt new file mode 100644 index 00000000..8ec5052b --- /dev/null +++ b/network/src/main/java/com/shifthackz/aisdv1/network/response/SwarmUiSessionResponse.kt @@ -0,0 +1,8 @@ +package com.shifthackz.aisdv1.network.response + +import com.google.gson.annotations.SerializedName + +data class SwarmUiSessionResponse( + @SerializedName("session_id") + val sessionId: String?, +) diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageScreen.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageScreen.kt index 566905a1..08065943 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageScreen.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageScreen.kt @@ -182,8 +182,7 @@ private fun ScreenContent( } } - ServerSource.OPEN_AI, - ServerSource.LOCAL -> { + else -> { Column( modifier = Modifier .padding(paddingValues) diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsScreen.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsScreen.kt index 9ad999a3..06fb45ed 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsScreen.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsScreen.kt @@ -188,6 +188,7 @@ private fun ContentSettingsState( ServerSource.OPEN_AI -> R.string.srv_type_open_ai ServerSource.STABILITY_AI -> R.string.srv_type_stability_ai ServerSource.LOCAL -> R.string.srv_type_local_short + ServerSource.SWARM_UI -> R.string.srv_type_swarm_ui }.asUiText(), onClick = { processIntent(SettingsIntent.NavigateConfiguration) }, ) diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupIntent.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupIntent.kt index a367ced4..e64c1980 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupIntent.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupIntent.kt @@ -12,6 +12,8 @@ sealed interface ServerSetupIntent : MviIntent { data class UpdateServerUrl(val url: String) : ServerSetupIntent + data class UpdateSwarmUiUrl(val url: String) : ServerSetupIntent + data class UpdateAuthType(val type: ServerSetupState.AuthType) : ServerSetupIntent data class UpdateLogin(val login: String) : ServerSetupIntent diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupState.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupState.kt index 58c46cc4..4fab4c96 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupState.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupState.kt @@ -21,6 +21,7 @@ data class ServerSetupState( val allowedModes: List = ServerSource.entries, val screenModal: Modal = Modal.None, val serverUrl: String = "", + val swarmUiUrl: String = "", val hordeApiKey: String = "", val huggingFaceApiKey: String = "", val openAiApiKey: String = "", @@ -36,6 +37,7 @@ data class ServerSetupState( val localCustomModel: Boolean = false, val passwordVisible: Boolean = false, val serverUrlValidationError: UiText? = null, + val swarmUiUrlValidationError: UiText? = null, val loginValidationError: UiText? = null, val passwordValidationError: UiText? = null, val hordeApiKeyValidationError: UiText? = null, diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt index 9d74b8d5..59e5d6e0 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt @@ -50,6 +50,18 @@ class ServerSetupViewModel( showBackNavArrow = launchSource == ServerSetupLaunchSource.SETTINGS, ) + private val credentials: AuthorizationCredentials + get() = when (currentState.mode) { + ServerSource.AUTOMATIC1111 -> { + if (!currentState.demoMode) currentState.credentialsDomain() + else AuthorizationCredentials.None + } + + ServerSource.SWARM_UI -> currentState.credentialsDomain() + + else -> AuthorizationCredentials.None + } + private var downloadDisposable: Disposable? = null init { @@ -73,6 +85,7 @@ class ServerSetupViewModel( mode = configuration.source, demoMode = configuration.demoMode, serverUrl = configuration.serverUrl, + swarmUiUrl = configuration.swarmUiUrl, authType = configuration.authType, ) .withCredentials(configuration.authCredentials) @@ -181,6 +194,10 @@ class ServerSetupViewModel( it.copy(serverUrl = intent.url, serverUrlValidationError = null) } + is ServerSetupIntent.UpdateSwarmUiUrl -> updateState { + it.copy(swarmUiUrl = intent.url, swarmUiUrlValidationError = null) + } + is ServerSetupIntent.LaunchUrl -> { emitEffect(ServerSetupEffect.LaunchUrl(intent.url)) } @@ -219,6 +236,7 @@ class ServerSetupViewModel( ServerSource.HUGGING_FACE -> connectToHuggingFace() ServerSource.OPEN_AI -> connectToOpenAi() ServerSource.STABILITY_AI -> connectToStabilityAi() + ServerSource.SWARM_UI -> connectToSwarmUi() } .doOnSubscribe { setScreenModal(Modal.Communicating(canCancel = false)) } .subscribeOnMainThread(schedulersProvider) @@ -236,34 +254,11 @@ class ServerSetupViewModel( private fun validate(): Boolean = when (currentState.mode) { ServerSource.AUTOMATIC1111 -> { if (currentState.demoMode) true - else { - val serverUrlValidation = urlValidator(currentState.serverUrl) - var isValid = serverUrlValidation.isValid - updateState { state -> - var newState = state.copy( - serverUrlValidationError = serverUrlValidation.mapToUi() - ) - if (currentState.authType == ServerSetupState.AuthType.HTTP_BASIC) { - val loginValidation = stringValidator(currentState.login) - val passwordValidation = stringValidator(currentState.password) - newState = newState.copy( - loginValidationError = loginValidation.mapToUi(), - passwordValidationError = passwordValidation.mapToUi() - ) - isValid = isValid && loginValidation.isValid && passwordValidation.isValid - } - if (serverUrlValidation.validationError is UrlValidator.Error.Localhost - && newState.loginValidationError == null - && newState.passwordValidationError == null - ) { - newState = newState.copy(screenModal = Modal.ConnectLocalHost) - } - newState - } - isValid - } + else validateServerUrlAndCredentials(currentState.serverUrl) } + ServerSource.SWARM_UI -> validateServerUrlAndCredentials(currentState.swarmUiUrl) + ServerSource.HORDE -> { if (currentState.hordeDefaultApiKey) true else { @@ -304,37 +299,70 @@ class ServerSetupViewModel( } } + private fun validateServerUrlAndCredentials(url: String): Boolean { + val serverUrlValidation = urlValidator(url) + var isValid = serverUrlValidation.isValid + updateState { state -> + var newState = state.copy( + serverUrlValidationError = if (state.mode == ServerSource.AUTOMATIC1111) { + serverUrlValidation.mapToUi() + } else { + state.serverUrlValidationError + }, + swarmUiUrlValidationError = if (state.mode == ServerSource.SWARM_UI) { + serverUrlValidation.mapToUi() + } else { + state.swarmUiUrlValidationError + }, + ) + if (currentState.authType == ServerSetupState.AuthType.HTTP_BASIC) { + val loginValidation = stringValidator(currentState.login) + val passwordValidation = stringValidator(currentState.password) + newState = newState.copy( + loginValidationError = loginValidation.mapToUi(), + passwordValidationError = passwordValidation.mapToUi() + ) + isValid = isValid && loginValidation.isValid && passwordValidation.isValid + } + if (serverUrlValidation.validationError is UrlValidator.Error.Localhost + && newState.loginValidationError == null + && newState.passwordValidationError == null + ) { + newState = newState.copy(screenModal = Modal.ConnectLocalHost) + } + newState + } + return isValid + } + private fun connectToAutomaticInstance(): Single> { val demoMode = currentState.demoMode val connectUrl = if (demoMode) currentState.demoModeUrl else currentState.serverUrl - val credentials = when (currentState.mode) { - ServerSource.AUTOMATIC1111 -> { - if (!demoMode) currentState.credentialsDomain() - else AuthorizationCredentials.None - } - - else -> AuthorizationCredentials.None - } return setupConnectionInterActor.connectToA1111( - connectUrl, - demoMode, - credentials, + url = connectUrl, + isDemo = demoMode, + credentials = credentials, ) } + private fun connectToSwarmUi() = setupConnectionInterActor.connectToSwarmUi( + url = currentState.swarmUiUrl, + credentials = credentials, + ) + private fun connectToHuggingFace() = with(currentState) { setupConnectionInterActor.connectToHuggingFace( - huggingFaceApiKey, - huggingFaceModel, + apiKey = huggingFaceApiKey, + model = huggingFaceModel, ) } private fun connectToOpenAi() = setupConnectionInterActor.connectToOpenAi( - currentState.openAiApiKey, + apiKey = currentState.openAiApiKey, ) private fun connectToStabilityAi() = setupConnectionInterActor.connectToStabilityAi( - currentState.stabilityAiApiKey, + apiKey = currentState.stabilityAiApiKey, ) private fun connectToHorde(): Single> { diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/components/ConfigurationModeButton.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/components/ConfigurationModeButton.kt index 0bf7a40b..0ace2923 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/components/ConfigurationModeButton.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/components/ConfigurationModeButton.kt @@ -65,7 +65,8 @@ fun ConfigurationModeButton( .size(42.dp) .padding(top = 8.dp, bottom = 8.dp), imageVector = when (mode) { - ServerSource.AUTOMATIC1111 -> Icons.Default.Computer + ServerSource.AUTOMATIC1111, + ServerSource.SWARM_UI -> Icons.Default.Computer ServerSource.HORDE, ServerSource.OPEN_AI, ServerSource.STABILITY_AI, @@ -86,6 +87,7 @@ fun ConfigurationModeButton( ServerSource.HUGGING_FACE -> R.string.srv_type_hugging_face ServerSource.OPEN_AI -> R.string.srv_type_open_ai ServerSource.STABILITY_AI -> R.string.srv_type_stability_ai + ServerSource.SWARM_UI -> R.string.srv_type_swarm_ui }), textAlign = TextAlign.Center, style = MaterialTheme.typography.bodyLarge, @@ -98,6 +100,7 @@ fun ConfigurationModeButton( ServerSource.OPEN_AI -> R.string.hint_open_ai_sub_title ServerSource.LOCAL -> R.string.hint_local_diffusion_sub_title ServerSource.STABILITY_AI -> R.string.hint_stability_ai_sub_title + ServerSource.SWARM_UI -> R.string.hint_swarm_ui_sub_title } descriptionId?.let { resId -> Text( diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/AuthCredentialsForm.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/AuthCredentialsForm.kt new file mode 100644 index 00000000..9daf1976 --- /dev/null +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/AuthCredentialsForm.kt @@ -0,0 +1,99 @@ +package com.shifthackz.aisdv1.presentation.screen.setup.forms + +import androidx.compose.foundation.layout.ColumnScope +import androidx.compose.foundation.text.KeyboardOptions +import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.filled.Visibility +import androidx.compose.material.icons.filled.VisibilityOff +import androidx.compose.material3.Icon +import androidx.compose.material3.IconButton +import androidx.compose.material3.MaterialTheme +import androidx.compose.material3.Text +import androidx.compose.material3.TextField +import androidx.compose.runtime.Composable +import androidx.compose.ui.Modifier +import androidx.compose.ui.res.stringResource +import androidx.compose.ui.text.input.KeyboardType +import androidx.compose.ui.text.input.PasswordVisualTransformation +import androidx.compose.ui.text.input.VisualTransformation +import com.shifthackz.aisdv1.core.model.asString +import com.shifthackz.aisdv1.core.model.asUiText +import com.shifthackz.aisdv1.presentation.R +import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupIntent +import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupState +import com.shifthackz.aisdv1.presentation.widget.input.DropdownTextField + +@Composable +fun ColumnScope.AuthCredentialsForm( + state: ServerSetupState, + processIntent: (ServerSetupIntent) -> Unit, + modifier: Modifier, +) { + DropdownTextField( + modifier = modifier, + label = R.string.auth_title.asUiText(), + items = ServerSetupState.AuthType.entries, + value = state.authType, + onItemSelected = { + processIntent(ServerSetupIntent.UpdateAuthType(it)) + }, + displayDelegate = { type -> + when (type) { + ServerSetupState.AuthType.ANONYMOUS -> R.string.auth_anonymous + ServerSetupState.AuthType.HTTP_BASIC -> R.string.auth_http_basic + }.asUiText() + } + ) + when (state.authType) { + ServerSetupState.AuthType.HTTP_BASIC -> { + TextField( + modifier = modifier, + value = state.login, + onValueChange = { + processIntent(ServerSetupIntent.UpdateLogin(it)) + }, + label = { Text(stringResource(id = R.string.hint_login)) }, + isError = state.loginValidationError != null, + supportingText = state.loginValidationError?.let { + { Text(it.asString(), color = MaterialTheme.colorScheme.error) } + }, + maxLines = 1, + ) + TextField( + modifier = modifier, + value = state.password, + onValueChange = { + processIntent(ServerSetupIntent.UpdatePassword(it)) + }, + label = { Text(stringResource(id = R.string.hint_password)) }, + isError = state.passwordValidationError != null, + keyboardOptions = KeyboardOptions(keyboardType = KeyboardType.Password), + visualTransformation = if (state.passwordVisible) { + VisualTransformation.None + } else { + PasswordVisualTransformation() + }, + supportingText = state.passwordValidationError?.let { + { Text(it.asString(), color = MaterialTheme.colorScheme.error) } + }, + trailingIcon = { + val image = if (state.passwordVisible) Icons.Filled.Visibility + else Icons.Filled.VisibilityOff + val description = if (state.passwordVisible) "Hide password" else "Show password" + IconButton( + onClick = { + processIntent( + ServerSetupIntent.UpdatePasswordVisibility( + state.passwordVisible, + ), + ) + }, + content = { Icon(image, description) }, + ) + }, + maxLines = 1, + ) + } + else -> Unit + } +} diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/Automatic1111Form.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/Automatic1111Form.kt index 2614a70e..8211b076 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/Automatic1111Form.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/Automatic1111Form.kt @@ -3,14 +3,9 @@ package com.shifthackz.aisdv1.presentation.screen.setup.forms import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.fillMaxWidth import androidx.compose.foundation.layout.padding -import androidx.compose.foundation.text.KeyboardOptions import androidx.compose.material.icons.Icons import androidx.compose.material.icons.automirrored.filled.Help import androidx.compose.material.icons.filled.DeveloperMode -import androidx.compose.material.icons.filled.Visibility -import androidx.compose.material.icons.filled.VisibilityOff -import androidx.compose.material3.Icon -import androidx.compose.material3.IconButton import androidx.compose.material3.MaterialTheme import androidx.compose.material3.Switch import androidx.compose.material3.Text @@ -19,9 +14,6 @@ import androidx.compose.runtime.Composable import androidx.compose.ui.Modifier import androidx.compose.ui.res.stringResource import androidx.compose.ui.text.font.FontWeight -import androidx.compose.ui.text.input.KeyboardType -import androidx.compose.ui.text.input.PasswordVisualTransformation -import androidx.compose.ui.text.input.VisualTransformation import androidx.compose.ui.text.style.TextAlign import androidx.compose.ui.unit.dp import com.shifthackz.aisdv1.core.model.asString @@ -29,7 +21,6 @@ import com.shifthackz.aisdv1.core.model.asUiText import com.shifthackz.aisdv1.presentation.R import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupIntent import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupState -import com.shifthackz.aisdv1.presentation.widget.input.DropdownTextField import com.shifthackz.aisdv1.presentation.widget.item.SettingsItem @Composable @@ -69,73 +60,11 @@ fun Automatic1111Form( maxLines = 1, ) if (!state.demoMode) { - DropdownTextField( + AuthCredentialsForm( + state = state, + processIntent = processIntent, modifier = fieldModifier, - label = R.string.auth_title.asUiText(), - items = ServerSetupState.AuthType.entries, - value = state.authType, - onItemSelected = { - processIntent(ServerSetupIntent.UpdateAuthType(it)) - }, - displayDelegate = { type -> - when (type) { - ServerSetupState.AuthType.ANONYMOUS -> R.string.auth_anonymous - ServerSetupState.AuthType.HTTP_BASIC -> R.string.auth_http_basic - }.asUiText() - } ) - when (state.authType) { - ServerSetupState.AuthType.HTTP_BASIC -> { - TextField( - modifier = fieldModifier, - value = state.login, - onValueChange = { - processIntent(ServerSetupIntent.UpdateLogin(it)) - }, - label = { Text(stringResource(id = R.string.hint_login)) }, - isError = state.loginValidationError != null, - supportingText = state.loginValidationError?.let { - { Text(it.asString(), color = MaterialTheme.colorScheme.error) } - }, - maxLines = 1, - ) - TextField( - modifier = fieldModifier, - value = state.password, - onValueChange = { - processIntent(ServerSetupIntent.UpdatePassword(it)) - }, - label = { Text(stringResource(id = R.string.hint_password)) }, - isError = state.passwordValidationError != null, - keyboardOptions = KeyboardOptions(keyboardType = KeyboardType.Password), - visualTransformation = if (state.passwordVisible) { - VisualTransformation.None - } else { - PasswordVisualTransformation() - }, - supportingText = state.passwordValidationError?.let { - { Text(it.asString(), color = MaterialTheme.colorScheme.error) } - }, - trailingIcon = { - val image = if (state.passwordVisible) Icons.Filled.Visibility - else Icons.Filled.VisibilityOff - val description = if (state.passwordVisible) "Hide password" else "Show password" - IconButton( - onClick = { - processIntent( - ServerSetupIntent.UpdatePasswordVisibility( - state.passwordVisible, - ), - ) - }, - content = { Icon(image, description) }, - ) - }, - maxLines = 1, - ) - } - else -> Unit - } } SettingsItem( modifier = Modifier diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/SwarmUiForm.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/SwarmUiForm.kt new file mode 100644 index 00000000..e10fe353 --- /dev/null +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/SwarmUiForm.kt @@ -0,0 +1,60 @@ +package com.shifthackz.aisdv1.presentation.screen.setup.forms + +import androidx.compose.foundation.layout.Column +import androidx.compose.foundation.layout.fillMaxWidth +import androidx.compose.foundation.layout.padding +import androidx.compose.material3.MaterialTheme +import androidx.compose.material3.Text +import androidx.compose.material3.TextField +import androidx.compose.runtime.Composable +import androidx.compose.ui.Modifier +import androidx.compose.ui.res.stringResource +import androidx.compose.ui.text.font.FontWeight +import androidx.compose.ui.text.style.TextAlign +import androidx.compose.ui.unit.dp +import com.shifthackz.aisdv1.core.model.asString +import com.shifthackz.aisdv1.presentation.R +import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupIntent +import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupState + +@Composable +fun SwarmUiForm( + modifier: Modifier = Modifier, + state: ServerSetupState, + processIntent: (ServerSetupIntent) -> Unit, +) { + Column( + modifier = modifier + .padding(horizontal = 16.dp), + ) { + Text( + modifier = Modifier + .fillMaxWidth() + .padding(top = 32.dp, bottom = 8.dp), + text = stringResource(id = R.string.hint_swarm_ui_title), + style = MaterialTheme.typography.bodyLarge, + textAlign = TextAlign.Center, + fontWeight = FontWeight.Bold, + ) + val fieldModifier = Modifier + .fillMaxWidth() + .padding(top = 8.dp) + TextField( + modifier = fieldModifier, + value = state.swarmUiUrl, + onValueChange = { + processIntent(ServerSetupIntent.UpdateSwarmUiUrl(it)) + }, + label = { Text(stringResource(id = R.string.hint_server_url)) }, + isError = state.swarmUiUrlValidationError != null, + supportingText = state.swarmUiUrlValidationError + ?.let { { Text(it.asString(), color = MaterialTheme.colorScheme.error) } }, + maxLines = 1, + ) + AuthCredentialsForm( + state = state, + processIntent = processIntent, + modifier = fieldModifier, + ) + } +} diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/mappers/FeatureTagMapper.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/mappers/FeatureTagMapper.kt index a817cb54..c2f043dc 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/mappers/FeatureTagMapper.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/mappers/FeatureTagMapper.kt @@ -11,6 +11,7 @@ fun FeatureTag.mapToUi(): String { id = when (this) { FeatureTag.Txt2Img -> R.string.home_tab_txt_to_img FeatureTag.Img2Img -> R.string.home_tab_img_to_img + FeatureTag.OwnServer -> R.string.hint_own_server FeatureTag.Lora -> R.string.title_lora FeatureTag.TextualInversion -> R.string.title_txt_inversion FeatureTag.HyperNetworks -> R.string.title_hyper_net diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/steps/ConfigurationStep.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/steps/ConfigurationStep.kt index f78f93d6..4d53ca89 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/steps/ConfigurationStep.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/steps/ConfigurationStep.kt @@ -12,6 +12,7 @@ import com.shifthackz.aisdv1.presentation.screen.setup.forms.HuggingFaceForm import com.shifthackz.aisdv1.presentation.screen.setup.forms.LocalDiffusionForm import com.shifthackz.aisdv1.presentation.screen.setup.forms.OpenAiForm import com.shifthackz.aisdv1.presentation.screen.setup.forms.StabilityAiForm +import com.shifthackz.aisdv1.presentation.screen.setup.forms.SwarmUiForm @Composable fun ConfigurationStep( @@ -52,6 +53,11 @@ fun ConfigurationStep( state = state, processIntent = processIntent, ) + + ServerSource.SWARM_UI -> SwarmUiForm( + state = state, + processIntent = processIntent, + ) } } } diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionComponent.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionComponent.kt index cb007e51..c6ec9334 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionComponent.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionComponent.kt @@ -57,6 +57,7 @@ fun EngineSelectionComponent( ServerSource.HORDE -> Unit ServerSource.OPEN_AI -> Unit + ServerSource.SWARM_UI -> Unit } } } diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputForm.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputForm.kt index c0fc3da6..6fea88c5 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputForm.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputForm.kt @@ -286,6 +286,7 @@ fun GenerationInputForm( ) } + ServerSource.SWARM_UI -> Unit } } diff --git a/presentation/src/main/res/values/strings.xml b/presentation/src/main/res/values/strings.xml index c3b31a7b..c4a9210d 100755 --- a/presentation/src/main/res/values/strings.xml +++ b/presentation/src/main/res/values/strings.xml @@ -69,6 +69,7 @@ HuggingFace Open AI Stability AI + Swarm UI Prompt Negative prompt @@ -77,6 +78,7 @@ Batch Multiple Models Offline generation + Own Server CFG Scale: %1$s Sampling method Style preset @@ -136,6 +138,9 @@ About Stability AI Stability AI Engine + Connect to Swarm UI + Swarm UI ... + Local Diffusion This configuration allows to run Stable Diffusion AI generations on your phone, with no need to connect to remote server/cloud. Warning! Local Diffusion functionality is in beta-test. Don\'t expect for high quality images using local mode. \n\nThis implementation may not work well on non-powerful phones. Generation performance and speed depends on your phone resources (CPU, RAM) and the size of generated image (the smaller the image size, the faster the generation). From 052f3f5d2af54cfa6b481d67e277b736f2091d74 Mon Sep 17 00:00:00 2001 From: ShiftHackZ Date: Sun, 4 Aug 2024 20:04:32 +0300 Subject: [PATCH 02/11] Swarm UI models implementation --- .../aisdv1/core/common/model/Hexagonal.kt | 15 ++++++ .../aisdv1/core/common/model/Quadruple.kt | 1 - .../aisdv1/data/di/LocalDataSourceModule.kt | 3 ++ .../aisdv1/data/di/RemoteDataSourceModule.kt | 6 +++ .../aisdv1/data/di/RepositoryModule.kt | 3 ++ .../local/SwarmUiModelsLocalDataSource.kt | 23 +++++++++ .../SwarmUiImageToImagePayloadMappers.kt | 27 ++++++++++ .../data/mappers/SwarmUiModelsMappers.kt | 47 +++++++++++++++++ .../SwarmUiTextToImagePayloadMappers.kt | 14 ++++- .../data/preference/PreferenceManagerImpl.kt | 9 ++++ .../data/preference/SessionPreferenceImpl.kt | 1 + .../SwarmUiGenerationRemoteDataSource.kt | 51 +++++++++++-------- .../remote/SwarmUiModelsRemoteDataSource.kt | 29 +++++++++++ .../remote/SwarmUiSessionDataSourceImpl.kt | 38 ++++++++++++++ .../SwarmUiGenerationRepositoryImpl.kt | 47 ++++++++++------- .../repository/SwarmUiModelsRepositoryImpl.kt | 26 ++++++++++ .../datasource/SwarmUiGenerationDataSource.kt | 17 +++++-- .../datasource/SwarmUiModelsDataSource.kt | 17 +++++++ .../datasource/SwarmUiSessionDataSource.kt | 7 +++ .../aisdv1/domain/di/DomainModule.kt | 3 ++ .../aisdv1/domain/entity/SwarmUiModel.kt | 7 +++ .../domain/preference/PreferenceManager.kt | 1 + .../repository/SwarmUiGenerationRepository.kt | 2 + .../repository/SwarmUiModelsRepository.kt | 11 ++++ .../generation/ImageToImageUseCaseImpl.kt | 3 ++ ...etchAndGetStabilityAiEnginesUseCaseImpl.kt | 3 +- .../FetchAndGetSwarmUiModelsUseCase.kt | 8 +++ .../FetchAndGetSwarmUiModelsUseCaseImpl.kt | 23 +++++++++ .../aisdv1/network/api/swarmui/SwarmUiApi.kt | 25 +++++++-- .../network/api/swarmui/SwarmUiApiImpl.kt | 20 ++++++-- .../aisdv1/network/model/SwarmUiModelRaw.kt | 12 +++++ .../request/SwarmUiGenerationRequest.kt | 21 ++++++++ .../network/request/SwarmUiModelsRequest.kt | 12 +++++ .../network/response/SwarmUiModelsResponse.kt | 9 ++++ .../screen/img2img/ImageToImageScreen.kt | 1 + .../screen/settings/SettingsState.kt | 2 +- .../widget/engine/EngineSelectionComponent.kt | 10 +++- .../widget/engine/EngineSelectionState.kt | 2 + .../widget/engine/EngineSelectionViewModel.kt | 18 +++++-- .../widget/input/GenerationInputForm.kt | 15 ++++-- .../1.json | 42 ++++++++++++++- .../aisdv1/storage/db/cache/CacheDatabase.kt | 4 ++ .../db/cache/contract/SwarmUiModelContract.kt | 10 ++++ .../storage/db/cache/dao/SwarmUiModelDao.kt | 23 +++++++++ .../db/cache/entity/SwarmUiModelEntity.kt | 19 +++++++ .../aisdv1/storage/di/DatabaseModule.kt | 1 + 46 files changed, 619 insertions(+), 69 deletions(-) create mode 100644 core/common/src/main/java/com/shifthackz/aisdv1/core/common/model/Hexagonal.kt create mode 100644 data/src/main/java/com/shifthackz/aisdv1/data/local/SwarmUiModelsLocalDataSource.kt create mode 100644 data/src/main/java/com/shifthackz/aisdv1/data/mappers/SwarmUiImageToImagePayloadMappers.kt create mode 100644 data/src/main/java/com/shifthackz/aisdv1/data/mappers/SwarmUiModelsMappers.kt create mode 100644 data/src/main/java/com/shifthackz/aisdv1/data/remote/SwarmUiModelsRemoteDataSource.kt create mode 100644 data/src/main/java/com/shifthackz/aisdv1/data/remote/SwarmUiSessionDataSourceImpl.kt create mode 100644 data/src/main/java/com/shifthackz/aisdv1/data/repository/SwarmUiModelsRepositoryImpl.kt create mode 100644 domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/SwarmUiModelsDataSource.kt create mode 100644 domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/SwarmUiSessionDataSource.kt create mode 100644 domain/src/main/java/com/shifthackz/aisdv1/domain/entity/SwarmUiModel.kt create mode 100644 domain/src/main/java/com/shifthackz/aisdv1/domain/repository/SwarmUiModelsRepository.kt create mode 100644 domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/swarmmodel/FetchAndGetSwarmUiModelsUseCase.kt create mode 100644 domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/swarmmodel/FetchAndGetSwarmUiModelsUseCaseImpl.kt create mode 100644 network/src/main/java/com/shifthackz/aisdv1/network/model/SwarmUiModelRaw.kt create mode 100644 network/src/main/java/com/shifthackz/aisdv1/network/request/SwarmUiModelsRequest.kt create mode 100644 network/src/main/java/com/shifthackz/aisdv1/network/response/SwarmUiModelsResponse.kt create mode 100644 storage/src/main/java/com/shifthackz/aisdv1/storage/db/cache/contract/SwarmUiModelContract.kt create mode 100644 storage/src/main/java/com/shifthackz/aisdv1/storage/db/cache/dao/SwarmUiModelDao.kt create mode 100644 storage/src/main/java/com/shifthackz/aisdv1/storage/db/cache/entity/SwarmUiModelEntity.kt diff --git a/core/common/src/main/java/com/shifthackz/aisdv1/core/common/model/Hexagonal.kt b/core/common/src/main/java/com/shifthackz/aisdv1/core/common/model/Hexagonal.kt new file mode 100644 index 00000000..4bc94dde --- /dev/null +++ b/core/common/src/main/java/com/shifthackz/aisdv1/core/common/model/Hexagonal.kt @@ -0,0 +1,15 @@ +package com.shifthackz.aisdv1.core.common.model + +import java.io.Serializable + +data class Hexagonal( + val first: A, + val second: B, + val third: C, + val fourth: D, + val fifth: E, + val sixth: F, +) : Serializable { + + override fun toString(): String = "($first, $second, $third, $fourth, $fifth, $sixth)" +} diff --git a/core/common/src/main/java/com/shifthackz/aisdv1/core/common/model/Quadruple.kt b/core/common/src/main/java/com/shifthackz/aisdv1/core/common/model/Quadruple.kt index 38499203..7234f7fa 100644 --- a/core/common/src/main/java/com/shifthackz/aisdv1/core/common/model/Quadruple.kt +++ b/core/common/src/main/java/com/shifthackz/aisdv1/core/common/model/Quadruple.kt @@ -2,7 +2,6 @@ package com.shifthackz.aisdv1.core.common.model import java.io.Serializable - data class Quadruple( val first: A, val second: B, diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/di/LocalDataSourceModule.kt b/data/src/main/java/com/shifthackz/aisdv1/data/di/LocalDataSourceModule.kt index 97410914..4d4a8012 100755 --- a/data/src/main/java/com/shifthackz/aisdv1/data/di/LocalDataSourceModule.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/di/LocalDataSourceModule.kt @@ -12,6 +12,7 @@ import com.shifthackz.aisdv1.data.local.StableDiffusionHyperNetworksLocalDataSou import com.shifthackz.aisdv1.data.local.StableDiffusionLorasLocalDataSource import com.shifthackz.aisdv1.data.local.StableDiffusionModelsLocalDataSource import com.shifthackz.aisdv1.data.local.StableDiffusionSamplersLocalDataSource +import com.shifthackz.aisdv1.data.local.SwarmUiModelsLocalDataSource import com.shifthackz.aisdv1.domain.datasource.DownloadableModelDataSource import com.shifthackz.aisdv1.domain.datasource.GenerationResultDataSource import com.shifthackz.aisdv1.domain.datasource.HuggingFaceModelsDataSource @@ -22,6 +23,7 @@ import com.shifthackz.aisdv1.domain.datasource.StableDiffusionHyperNetworksDataS import com.shifthackz.aisdv1.domain.datasource.StableDiffusionLorasDataSource import com.shifthackz.aisdv1.domain.datasource.StableDiffusionModelsDataSource import com.shifthackz.aisdv1.domain.datasource.StableDiffusionSamplersDataSource +import com.shifthackz.aisdv1.domain.datasource.SwarmUiModelsDataSource import com.shifthackz.aisdv1.domain.gateway.DatabaseClearGateway import org.koin.android.ext.koin.androidContext import org.koin.core.module.dsl.factoryOf @@ -38,6 +40,7 @@ val localDataSourceModule = module { factoryOf(::StableDiffusionLorasLocalDataSource) bind StableDiffusionLorasDataSource.Local::class factoryOf(::StableDiffusionHyperNetworksLocalDataSource) bind StableDiffusionHyperNetworksDataSource.Local::class factoryOf(::StableDiffusionEmbeddingsLocalDataSource) bind StableDiffusionEmbeddingsDataSource.Local::class + factoryOf(::SwarmUiModelsLocalDataSource) bind SwarmUiModelsDataSource.Local::class factoryOf(::ServerConfigurationLocalDataSource) bind ServerConfigurationDataSource.Local::class factoryOf(::GenerationResultLocalDataSource) bind GenerationResultDataSource.Local::class factoryOf(::DownloadableModelLocalDataSource) bind DownloadableModelDataSource.Local::class diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/di/RemoteDataSourceModule.kt b/data/src/main/java/com/shifthackz/aisdv1/data/di/RemoteDataSourceModule.kt index d3eb6205..3a834ea8 100755 --- a/data/src/main/java/com/shifthackz/aisdv1/data/di/RemoteDataSourceModule.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/di/RemoteDataSourceModule.kt @@ -21,6 +21,8 @@ import com.shifthackz.aisdv1.data.remote.StableDiffusionLorasRemoteDataSource import com.shifthackz.aisdv1.data.remote.StableDiffusionModelsRemoteDataSource import com.shifthackz.aisdv1.data.remote.StableDiffusionSamplersRemoteDataSource import com.shifthackz.aisdv1.data.remote.SwarmUiGenerationRemoteDataSource +import com.shifthackz.aisdv1.data.remote.SwarmUiModelsRemoteDataSource +import com.shifthackz.aisdv1.data.remote.SwarmUiSessionDataSourceImpl import com.shifthackz.aisdv1.domain.datasource.DownloadableModelDataSource import com.shifthackz.aisdv1.domain.datasource.HordeGenerationDataSource import com.shifthackz.aisdv1.domain.datasource.HuggingFaceGenerationDataSource @@ -38,6 +40,8 @@ import com.shifthackz.aisdv1.domain.datasource.StableDiffusionLorasDataSource import com.shifthackz.aisdv1.domain.datasource.StableDiffusionModelsDataSource import com.shifthackz.aisdv1.domain.datasource.StableDiffusionSamplersDataSource import com.shifthackz.aisdv1.domain.datasource.SwarmUiGenerationDataSource +import com.shifthackz.aisdv1.domain.datasource.SwarmUiModelsDataSource +import com.shifthackz.aisdv1.domain.datasource.SwarmUiSessionDataSource import com.shifthackz.aisdv1.domain.entity.ServerSource import com.shifthackz.aisdv1.domain.gateway.ServerConnectivityGateway import com.shifthackz.aisdv1.domain.preference.PreferenceManager @@ -67,7 +71,9 @@ val remoteDataSourceModule = module { factoryOf(::HordeGenerationRemoteDataSource) bind HordeGenerationDataSource.Remote::class factoryOf(::HuggingFaceGenerationRemoteDataSource) bind HuggingFaceGenerationDataSource.Remote::class factoryOf(::OpenAiGenerationRemoteDataSource) bind OpenAiGenerationDataSource.Remote::class + factoryOf(::SwarmUiSessionDataSourceImpl) bind SwarmUiSessionDataSource::class factoryOf(::SwarmUiGenerationRemoteDataSource) bind SwarmUiGenerationDataSource.Remote::class + factoryOf(::SwarmUiModelsRemoteDataSource) bind SwarmUiModelsDataSource.Remote::class factoryOf(::StableDiffusionGenerationRemoteDataSource) bind StableDiffusionGenerationDataSource.Remote::class factoryOf(::StableDiffusionSamplersRemoteDataSource) bind StableDiffusionSamplersDataSource.Remote::class factoryOf(::StableDiffusionModelsRemoteDataSource) bind StableDiffusionModelsDataSource.Remote::class diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/di/RepositoryModule.kt b/data/src/main/java/com/shifthackz/aisdv1/data/di/RepositoryModule.kt index 7e61b96e..6e665fb3 100755 --- a/data/src/main/java/com/shifthackz/aisdv1/data/di/RepositoryModule.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/di/RepositoryModule.kt @@ -21,6 +21,7 @@ import com.shifthackz.aisdv1.data.repository.StableDiffusionLorasRepositoryImpl import com.shifthackz.aisdv1.data.repository.StableDiffusionModelsRepositoryImpl import com.shifthackz.aisdv1.data.repository.StableDiffusionSamplersRepositoryImpl import com.shifthackz.aisdv1.data.repository.SwarmUiGenerationRepositoryImpl +import com.shifthackz.aisdv1.data.repository.SwarmUiModelsRepositoryImpl import com.shifthackz.aisdv1.data.repository.TemporaryGenerationResultRepositoryImpl import com.shifthackz.aisdv1.data.repository.WakeLockRepositoryImpl import com.shifthackz.aisdv1.domain.repository.DownloadableModelRepository @@ -42,6 +43,7 @@ import com.shifthackz.aisdv1.domain.repository.StableDiffusionLorasRepository import com.shifthackz.aisdv1.domain.repository.StableDiffusionModelsRepository import com.shifthackz.aisdv1.domain.repository.StableDiffusionSamplersRepository import com.shifthackz.aisdv1.domain.repository.SwarmUiGenerationRepository +import com.shifthackz.aisdv1.domain.repository.SwarmUiModelsRepository import com.shifthackz.aisdv1.domain.repository.TemporaryGenerationResultRepository import com.shifthackz.aisdv1.domain.repository.WakeLockRepository import org.koin.android.ext.koin.androidContext @@ -63,6 +65,7 @@ val repositoryModule = module { factoryOf(::HuggingFaceGenerationRepositoryImpl) bind HuggingFaceGenerationRepository::class factoryOf(::OpenAiGenerationRepositoryImpl) bind OpenAiGenerationRepository::class factoryOf(::SwarmUiGenerationRepositoryImpl) bind SwarmUiGenerationRepository::class + factoryOf(::SwarmUiModelsRepositoryImpl) bind SwarmUiModelsRepository::class factoryOf(::StabilityAiGenerationRepositoryImpl) bind StabilityAiGenerationRepository::class factoryOf(::StabilityAiCreditsRepositoryImpl) bind StabilityAiCreditsRepository::class factoryOf(::StabilityAiEnginesRepositoryImpl) bind StabilityAiEnginesRepository::class diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/local/SwarmUiModelsLocalDataSource.kt b/data/src/main/java/com/shifthackz/aisdv1/data/local/SwarmUiModelsLocalDataSource.kt new file mode 100644 index 00000000..003c61a5 --- /dev/null +++ b/data/src/main/java/com/shifthackz/aisdv1/data/local/SwarmUiModelsLocalDataSource.kt @@ -0,0 +1,23 @@ +package com.shifthackz.aisdv1.data.local + +import com.shifthackz.aisdv1.data.mappers.mapDomainToEntity +import com.shifthackz.aisdv1.data.mappers.mapEntityToDomain +import com.shifthackz.aisdv1.domain.datasource.SwarmUiModelsDataSource +import com.shifthackz.aisdv1.domain.entity.SwarmUiModel +import com.shifthackz.aisdv1.storage.db.cache.dao.SwarmUiModelDao +import com.shifthackz.aisdv1.storage.db.cache.entity.SwarmUiModelEntity +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single + +internal class SwarmUiModelsLocalDataSource( + private val dao: SwarmUiModelDao, +) : SwarmUiModelsDataSource.Local { + + override fun getModels(): Single> = dao + .queryAll() + .map(List::mapEntityToDomain) + + override fun insertModels(models: List): Completable = models + .mapDomainToEntity() + .let(dao::insertList) +} diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/SwarmUiImageToImagePayloadMappers.kt b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/SwarmUiImageToImagePayloadMappers.kt new file mode 100644 index 00000000..8514bef4 --- /dev/null +++ b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/SwarmUiImageToImagePayloadMappers.kt @@ -0,0 +1,27 @@ +package com.shifthackz.aisdv1.data.mappers + +import com.shifthackz.aisdv1.domain.entity.ImageToImagePayload +import com.shifthackz.aisdv1.network.request.SwarmUiGenerationRequest + +fun ImageToImagePayload.mapToSwarmUiRequest( + sessionId: String, + swarmUiModel: String, +): SwarmUiGenerationRequest = with(this) { + SwarmUiGenerationRequest( + sessionId = sessionId, + model = swarmUiModel, + initImage = "data:image/png;base64,${base64Image.trim('\n').trim('\u003d')}", +// initImage = base64Image, +// initImage = base64Image, + images = 1, + prompt = prompt, + negativePrompt = negativePrompt, + width = width, + height = height, + seed = seed.trim().ifEmpty { null }, + variationSeed = subSeed.trim().ifEmpty { null }, + variationSeedStrength = subSeedStrength.takeIf { it >= 0.1 }?.toString(), + cfgScale = cfgScale, + steps = samplingSteps, + ) +} diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/SwarmUiModelsMappers.kt b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/SwarmUiModelsMappers.kt new file mode 100644 index 00000000..1926fd3f --- /dev/null +++ b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/SwarmUiModelsMappers.kt @@ -0,0 +1,47 @@ +package com.shifthackz.aisdv1.data.mappers + +import com.shifthackz.aisdv1.domain.entity.SwarmUiModel +import com.shifthackz.aisdv1.network.model.SwarmUiModelRaw +import com.shifthackz.aisdv1.network.response.SwarmUiModelsResponse +import com.shifthackz.aisdv1.storage.db.cache.entity.SwarmUiModelEntity + +//region RAW --> DOMAIN +fun SwarmUiModelsResponse.mapRawToDomain(): List = with(this) { + this.files?.mapRawToDomain() ?: emptyList() +} + +fun List.mapRawToDomain(): List = map(SwarmUiModelRaw::mapRawToDomain) + +fun SwarmUiModelRaw.mapRawToDomain(): SwarmUiModel = with(this) { + SwarmUiModel( + name = name ?: "", + title = title ?: "", + author = author ?: "", + ) +} +//endregion + +//region DOMAIN --> ENTITY +fun List.mapDomainToEntity(): List = map(SwarmUiModel::mapDomainToEntity) + +fun SwarmUiModel.mapDomainToEntity(): SwarmUiModelEntity = with(this) { + SwarmUiModelEntity( + id = "${name}_${title}", + name = name, + title = title, + author = author, + ) +} +//endregion + +//region ENTITY --> DOMAIN +fun List.mapEntityToDomain(): List = map(SwarmUiModelEntity::mapEntityToDomain) + +fun SwarmUiModelEntity.mapEntityToDomain(): SwarmUiModel = with(this) { + SwarmUiModel( + name = name, + title = title, + author = author, + ) +} +//endregion diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/SwarmUiTextToImagePayloadMappers.kt b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/SwarmUiTextToImagePayloadMappers.kt index 2998c824..f3ef2863 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/SwarmUiTextToImagePayloadMappers.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/SwarmUiTextToImagePayloadMappers.kt @@ -3,14 +3,24 @@ package com.shifthackz.aisdv1.data.mappers import com.shifthackz.aisdv1.domain.entity.TextToImagePayload import com.shifthackz.aisdv1.network.request.SwarmUiGenerationRequest -fun TextToImagePayload.mapToSwarmUiRequest(sessionId: String): SwarmUiGenerationRequest = with(this) { +fun TextToImagePayload.mapToSwarmUiRequest( + sessionId: String, + swarmUiModel: String, +): SwarmUiGenerationRequest = with(this) { SwarmUiGenerationRequest( sessionId = sessionId, - model = "OfficialStableDiffusion/sd_xl_base_1.0", + model = swarmUiModel, + initImage = null, images = 1, prompt = prompt, + negativePrompt = negativePrompt, width = width, height = height, + seed = seed.trim().ifEmpty { null }, + variationSeed = subSeed.trim().ifEmpty { null }, + variationSeedStrength = subSeedStrength.takeIf { it >= 0.1 }?.toString(), + cfgScale = cfgScale, + steps = samplingSteps, ) } // diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImpl.kt index e250e4d6..c1651153 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImpl.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImpl.kt @@ -34,6 +34,14 @@ class PreferenceManagerImpl( .apply() .also { onPreferencesChanged() } + override var swarmModel: String + get() = preferences.getString(KEY_SWARM_MODEL, "") ?: "" + set(value) = preferences + .edit() + .putString(KEY_SWARM_MODEL, value) + .apply() + .also { onPreferencesChanged() } + override var demoMode: Boolean get() = preferences.getBoolean(KEY_DEMO_MODE, false) set(value) = preferences.edit() @@ -223,6 +231,7 @@ class PreferenceManagerImpl( companion object { const val KEY_SERVER_URL = "key_server_url" const val KEY_SWARM_SERVER_URL = "key_swarm_server_url" + const val KEY_SWARM_MODEL = "key_swarm_model" const val KEY_DEMO_MODE = "key_demo_mode" const val KEY_MONITOR_CONNECTIVITY = "key_monitor_connectivity" const val KEY_AI_AUTO_SAVE = "key_ai_auto_save" diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/preference/SessionPreferenceImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/preference/SessionPreferenceImpl.kt index b076d72a..162016fc 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/preference/SessionPreferenceImpl.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/preference/SessionPreferenceImpl.kt @@ -12,6 +12,7 @@ class SessionPreferenceImpl : SessionPreference { set(value) { _coinsPerDay = value } + override var swarmUiSessionId: String get() = _swarmUiSessionId set(value) { diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/remote/SwarmUiGenerationRemoteDataSource.kt b/data/src/main/java/com/shifthackz/aisdv1/data/remote/SwarmUiGenerationRemoteDataSource.kt index 1a394a14..ed7fda3d 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/remote/SwarmUiGenerationRemoteDataSource.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/remote/SwarmUiGenerationRemoteDataSource.kt @@ -6,10 +6,12 @@ import com.shifthackz.aisdv1.data.mappers.mapCloudToAiGenResult import com.shifthackz.aisdv1.data.mappers.mapToSwarmUiRequest import com.shifthackz.aisdv1.data.provider.ServerUrlProvider import com.shifthackz.aisdv1.domain.datasource.SwarmUiGenerationDataSource +import com.shifthackz.aisdv1.domain.entity.AiGenerationResult +import com.shifthackz.aisdv1.domain.entity.ImageToImagePayload import com.shifthackz.aisdv1.domain.entity.TextToImagePayload import com.shifthackz.aisdv1.network.api.swarmui.SwarmUiApi -import com.shifthackz.aisdv1.network.api.swarmui.SwarmUiApi.Companion.PATH_SESSION -import com.shifthackz.aisdv1.network.api.swarmui.SwarmUiApi.Companion.PATH_TXT_TO_IMG +import com.shifthackz.aisdv1.network.api.swarmui.SwarmUiApi.Companion.PATH_GENERATE +import com.shifthackz.aisdv1.network.request.SwarmUiGenerationRequest import io.reactivex.rxjava3.core.Single class SwarmUiGenerationRemoteDataSource( @@ -18,18 +20,33 @@ class SwarmUiGenerationRemoteDataSource( private val converter: BitmapToBase64Converter, ) : SwarmUiGenerationDataSource.Remote { - override fun getNewSession() = PATH_SESSION - .let(serverUrlProvider::invoke) - .flatMap(::getSessionForUrl) - - override fun getNewSession(url: String) = - getSessionForUrl("${url.fixUrlSlashes()}/$PATH_SESSION") - override fun textToImage( sessionId: String, - payload: TextToImagePayload, - ) = serverUrlProvider(PATH_TXT_TO_IMG) - .flatMap { url -> api.textToImage(url, payload.mapToSwarmUiRequest(sessionId)) } + model: String, + payload: TextToImagePayload + ): Single = + generate( + payload = payload, + request = payload.mapToSwarmUiRequest(sessionId, model), + ) + .map(Pair::mapCloudToAiGenResult) + + override fun imageToImage( + sessionId: String, + model: String, + payload: ImageToImagePayload, + ): Single = + generate( + payload = payload, + request = payload.mapToSwarmUiRequest(sessionId, model), + ) + .map(Pair::mapCloudToAiGenResult) + + private fun generate( + payload: T, + request: SwarmUiGenerationRequest, + ): Single> = serverUrlProvider(PATH_GENERATE) + .flatMap { url -> api.generate(url, request) } .flatMap { response -> serverUrlProvider("").map { url -> response to url } } @@ -44,14 +61,4 @@ class SwarmUiGenerationRemoteDataSource( .flatMap(converter::invoke) .map(BitmapToBase64Converter.Output::base64ImageString) .map { base64 -> payload to base64 } - .map(Pair::mapCloudToAiGenResult) - - private fun getSessionForUrl(url: String) = api - .getNewSession(url) - .flatMap { response -> - response.sessionId - ?.takeIf(String::isNotBlank) - ?.let { Single.just(it) } - ?: Single.error(IllegalStateException("Bad session ID.")) - } } diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/remote/SwarmUiModelsRemoteDataSource.kt b/data/src/main/java/com/shifthackz/aisdv1/data/remote/SwarmUiModelsRemoteDataSource.kt new file mode 100644 index 00000000..50e0605d --- /dev/null +++ b/data/src/main/java/com/shifthackz/aisdv1/data/remote/SwarmUiModelsRemoteDataSource.kt @@ -0,0 +1,29 @@ +package com.shifthackz.aisdv1.data.remote + +import com.shifthackz.aisdv1.data.mappers.mapRawToDomain +import com.shifthackz.aisdv1.data.provider.ServerUrlProvider +import com.shifthackz.aisdv1.domain.datasource.SwarmUiModelsDataSource +import com.shifthackz.aisdv1.domain.entity.SwarmUiModel +import com.shifthackz.aisdv1.network.api.swarmui.SwarmUiApi +import com.shifthackz.aisdv1.network.api.swarmui.SwarmUiApi.Companion.PATH_MODELS +import com.shifthackz.aisdv1.network.request.SwarmUiModelsRequest +import com.shifthackz.aisdv1.network.response.SwarmUiModelsResponse +import io.reactivex.rxjava3.core.Single + +internal class SwarmUiModelsRemoteDataSource( + private val serverUrlProvider: ServerUrlProvider, + private val api: SwarmUiApi, +) : SwarmUiModelsDataSource.Remote { + + override fun fetchSwarmModels(sessionId: String): Single> = PATH_MODELS + .let(serverUrlProvider::invoke) + .flatMap { url -> + val request = SwarmUiModelsRequest( + sessionId = sessionId, + path = "", + depth = 3, + ) + api.fetchModels(url, request) + } + .map(SwarmUiModelsResponse::mapRawToDomain) +} diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/remote/SwarmUiSessionDataSourceImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/remote/SwarmUiSessionDataSourceImpl.kt new file mode 100644 index 00000000..5e66e0ab --- /dev/null +++ b/data/src/main/java/com/shifthackz/aisdv1/data/remote/SwarmUiSessionDataSourceImpl.kt @@ -0,0 +1,38 @@ +package com.shifthackz.aisdv1.data.remote + +import com.shifthackz.aisdv1.core.common.extensions.fixUrlSlashes +import com.shifthackz.aisdv1.data.provider.ServerUrlProvider +import com.shifthackz.aisdv1.domain.datasource.SwarmUiSessionDataSource +import com.shifthackz.aisdv1.domain.preference.SessionPreference +import com.shifthackz.aisdv1.network.api.swarmui.SwarmUiApi +import com.shifthackz.aisdv1.network.api.swarmui.SwarmUiApi.Companion.PATH_SESSION +import io.reactivex.rxjava3.core.Single + +internal class SwarmUiSessionDataSourceImpl( + private val api: SwarmUiApi, + private val sessionPreference: SessionPreference, + private val serverUrlProvider: ServerUrlProvider, +) : SwarmUiSessionDataSource { + + override fun getSessionId(connectUrl: String?): Single = + if (sessionPreference.swarmUiSessionId.isBlank()) { + val chain = connectUrl + ?.let { url -> "$url/$PATH_SESSION".fixUrlSlashes() } + ?.let(api::getNewSession) + ?: serverUrlProvider(PATH_SESSION).flatMap(api::getNewSession) + + chain + .flatMap { response -> + response.sessionId + ?.takeIf(String::isNotBlank) + ?.let { Single.just(it) } + ?: Single.error(IllegalStateException("Bad session ID.")) + } + .map { sessionId -> + sessionPreference.swarmUiSessionId = sessionId + sessionId + } + } else { + Single.just(sessionPreference.swarmUiSessionId) + } +} diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/repository/SwarmUiGenerationRepositoryImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/repository/SwarmUiGenerationRepositoryImpl.kt index 8d4a2858..e4919e26 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/repository/SwarmUiGenerationRepositoryImpl.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/repository/SwarmUiGenerationRepositoryImpl.kt @@ -4,20 +4,23 @@ import com.shifthackz.aisdv1.core.imageprocessing.Base64ToBitmapConverter import com.shifthackz.aisdv1.data.core.CoreGenerationRepository import com.shifthackz.aisdv1.domain.datasource.GenerationResultDataSource import com.shifthackz.aisdv1.domain.datasource.SwarmUiGenerationDataSource +import com.shifthackz.aisdv1.domain.datasource.SwarmUiSessionDataSource +import com.shifthackz.aisdv1.domain.entity.AiGenerationResult +import com.shifthackz.aisdv1.domain.entity.ImageToImagePayload import com.shifthackz.aisdv1.domain.entity.TextToImagePayload import com.shifthackz.aisdv1.domain.gateway.MediaStoreGateway import com.shifthackz.aisdv1.domain.preference.PreferenceManager -import com.shifthackz.aisdv1.domain.preference.SessionPreference import com.shifthackz.aisdv1.domain.repository.SwarmUiGenerationRepository +import io.reactivex.rxjava3.core.Completable import io.reactivex.rxjava3.core.Single internal class SwarmUiGenerationRepositoryImpl( mediaStoreGateway: MediaStoreGateway, base64ToBitmapConverter: Base64ToBitmapConverter, localDataSource: GenerationResultDataSource.Local, - private val remoteDataSource: SwarmUiGenerationDataSource.Remote, private val preferenceManager: PreferenceManager, - private val sessionPreference: SessionPreference, + private val session: SwarmUiSessionDataSource, + private val remoteDataSource: SwarmUiGenerationDataSource.Remote, ) : CoreGenerationRepository( mediaStoreGateway, base64ToBitmapConverter, @@ -25,27 +28,33 @@ internal class SwarmUiGenerationRepositoryImpl( preferenceManager, ), SwarmUiGenerationRepository { - override fun checkApiAvailability() = obtainSessionId() + override fun checkApiAvailability(): Completable = session + .getSessionId() .ignoreElement() - override fun checkApiAvailability(url: String) = obtainSessionId(url) + override fun checkApiAvailability(url: String): Completable = session + .getSessionId(url) .ignoreElement() - override fun generateFromText(payload: TextToImagePayload) = obtainSessionId() - .flatMap { sessionId -> remoteDataSource.textToImage(sessionId, payload) } + override fun generateFromText(payload: TextToImagePayload): Single = session + .getSessionId() + .flatMap { sessionId -> + remoteDataSource.textToImage( + sessionId = sessionId, + model = preferenceManager.swarmModel, + payload = payload, + ) + } .flatMap(::insertGenerationResult) - private fun obtainSessionId(connectUrl: String? = null) = - if (sessionPreference.swarmUiSessionId.isBlank()) { - val chain = connectUrl - ?.let(remoteDataSource::getNewSession) - ?: remoteDataSource.getNewSession() - - chain.map { sessionId -> - sessionPreference.swarmUiSessionId = sessionId - sessionId - } - } else { - Single.just(sessionPreference.swarmUiSessionId) + override fun generateFromImage(payload: ImageToImagePayload): Single = session + .getSessionId() + .flatMap { sessionId -> + remoteDataSource.imageToImage( + sessionId = sessionId, + model = preferenceManager.swarmModel, + payload = payload, + ) } + .flatMap(::insertGenerationResult) } diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/repository/SwarmUiModelsRepositoryImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/repository/SwarmUiModelsRepositoryImpl.kt new file mode 100644 index 00000000..cda65ffb --- /dev/null +++ b/data/src/main/java/com/shifthackz/aisdv1/data/repository/SwarmUiModelsRepositoryImpl.kt @@ -0,0 +1,26 @@ +package com.shifthackz.aisdv1.data.repository + +import com.shifthackz.aisdv1.domain.datasource.SwarmUiModelsDataSource +import com.shifthackz.aisdv1.domain.datasource.SwarmUiSessionDataSource +import com.shifthackz.aisdv1.domain.entity.SwarmUiModel +import com.shifthackz.aisdv1.domain.repository.SwarmUiModelsRepository +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single + +internal class SwarmUiModelsRepositoryImpl( + private val session: SwarmUiSessionDataSource, + private val rds: SwarmUiModelsDataSource.Remote, + private val lds: SwarmUiModelsDataSource.Local, +) : SwarmUiModelsRepository { + + override fun fetchModels(): Completable = session + .getSessionId() + .flatMap(rds::fetchSwarmModels) + .flatMapCompletable(lds::insertModels) + + override fun fetchAndGetModels(): Single> = fetchModels() + .onErrorComplete() + .andThen(getModels()) + + override fun getModels(): Single> = lds.getModels() +} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/SwarmUiGenerationDataSource.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/SwarmUiGenerationDataSource.kt index bafca9ab..88c388a0 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/SwarmUiGenerationDataSource.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/SwarmUiGenerationDataSource.kt @@ -1,14 +1,23 @@ package com.shifthackz.aisdv1.domain.datasource import com.shifthackz.aisdv1.domain.entity.AiGenerationResult +import com.shifthackz.aisdv1.domain.entity.ImageToImagePayload import com.shifthackz.aisdv1.domain.entity.TextToImagePayload import io.reactivex.rxjava3.core.Single sealed interface SwarmUiGenerationDataSource { - interface Remote { - fun getNewSession(): Single - fun getNewSession(url: String): Single - fun textToImage(sessionId: String, payload: TextToImagePayload): Single + interface Remote : SwarmUiGenerationDataSource { + fun textToImage( + sessionId: String, + model: String, + payload: TextToImagePayload, + ): Single + + fun imageToImage( + sessionId: String, + model: String, + payload: ImageToImagePayload, + ): Single } } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/SwarmUiModelsDataSource.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/SwarmUiModelsDataSource.kt new file mode 100644 index 00000000..f47f21c7 --- /dev/null +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/SwarmUiModelsDataSource.kt @@ -0,0 +1,17 @@ +package com.shifthackz.aisdv1.domain.datasource + +import com.shifthackz.aisdv1.domain.entity.SwarmUiModel +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single + +interface SwarmUiModelsDataSource { + + interface Remote : SwarmUiModelsDataSource { + fun fetchSwarmModels(sessionId: String): Single> + } + + interface Local : SwarmUiModelsDataSource { + fun getModels(): Single> + fun insertModels(models: List): Completable + } +} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/SwarmUiSessionDataSource.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/SwarmUiSessionDataSource.kt new file mode 100644 index 00000000..0324599a --- /dev/null +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/SwarmUiSessionDataSource.kt @@ -0,0 +1,7 @@ +package com.shifthackz.aisdv1.domain.datasource + +import io.reactivex.rxjava3.core.Single + +interface SwarmUiSessionDataSource { + fun getSessionId(connectUrl: String? = null): Single +} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/di/DomainModule.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/di/DomainModule.kt index e2823a4d..7170f723 100755 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/di/DomainModule.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/di/DomainModule.kt @@ -100,6 +100,8 @@ import com.shifthackz.aisdv1.domain.usecase.stabilityai.FetchAndGetStabilityAiEn import com.shifthackz.aisdv1.domain.usecase.stabilityai.FetchAndGetStabilityAiEnginesUseCaseImpl import com.shifthackz.aisdv1.domain.usecase.stabilityai.ObserveStabilityAiCreditsUseCase import com.shifthackz.aisdv1.domain.usecase.stabilityai.ObserveStabilityAiCreditsUseCaseImpl +import com.shifthackz.aisdv1.domain.usecase.swarmmodel.FetchAndGetSwarmUiModelsUseCase +import com.shifthackz.aisdv1.domain.usecase.swarmmodel.FetchAndGetSwarmUiModelsUseCaseImpl import com.shifthackz.aisdv1.domain.usecase.wakelock.AcquireWakelockUseCase import com.shifthackz.aisdv1.domain.usecase.wakelock.AcquireWakelockUseCaseImpl import com.shifthackz.aisdv1.domain.usecase.wakelock.ReleaseWakeLockUseCase @@ -114,6 +116,7 @@ internal val useCasesModule = module { factoryOf(::PingStableDiffusionServiceUseCaseImpl) bind PingStableDiffusionServiceUseCase::class factoryOf(::ClearAppCacheUseCaseImpl) bind ClearAppCacheUseCase::class factoryOf(::DataPreLoaderUseCaseImpl) bind DataPreLoaderUseCase::class + factoryOf(::FetchAndGetSwarmUiModelsUseCaseImpl) bind FetchAndGetSwarmUiModelsUseCase::class factoryOf(::GetStableDiffusionModelsUseCaseImpl) bind GetStableDiffusionModelsUseCase::class factoryOf(::SelectStableDiffusionModelUseCaseImpl) bind SelectStableDiffusionModelUseCase::class factoryOf(::GetGenerationResultPagedUseCaseImpl) bind GetGenerationResultPagedUseCase::class diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/SwarmUiModel.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/SwarmUiModel.kt new file mode 100644 index 00000000..e79e2673 --- /dev/null +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/SwarmUiModel.kt @@ -0,0 +1,7 @@ +package com.shifthackz.aisdv1.domain.entity + +class SwarmUiModel( + val name: String, + val title: String, + val author: String, +) diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/preference/PreferenceManager.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/preference/PreferenceManager.kt index 06311558..8f057184 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/preference/PreferenceManager.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/preference/PreferenceManager.kt @@ -7,6 +7,7 @@ import io.reactivex.rxjava3.core.Flowable interface PreferenceManager { var serverUrl: String var swarmServerUrl: String + var swarmModel: String var demoMode: Boolean var monitorConnectivity: Boolean var autoSaveAiResults: Boolean diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/SwarmUiGenerationRepository.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/SwarmUiGenerationRepository.kt index bf4bb83e..824ccb34 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/SwarmUiGenerationRepository.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/SwarmUiGenerationRepository.kt @@ -1,6 +1,7 @@ package com.shifthackz.aisdv1.domain.repository import com.shifthackz.aisdv1.domain.entity.AiGenerationResult +import com.shifthackz.aisdv1.domain.entity.ImageToImagePayload import com.shifthackz.aisdv1.domain.entity.TextToImagePayload import io.reactivex.rxjava3.core.Completable import io.reactivex.rxjava3.core.Single @@ -9,4 +10,5 @@ interface SwarmUiGenerationRepository { fun checkApiAvailability(): Completable fun checkApiAvailability(url: String): Completable fun generateFromText(payload: TextToImagePayload): Single + fun generateFromImage(payload: ImageToImagePayload): Single } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/SwarmUiModelsRepository.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/SwarmUiModelsRepository.kt new file mode 100644 index 00000000..5e97c9ab --- /dev/null +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/SwarmUiModelsRepository.kt @@ -0,0 +1,11 @@ +package com.shifthackz.aisdv1.domain.repository + +import com.shifthackz.aisdv1.domain.entity.SwarmUiModel +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single + +interface SwarmUiModelsRepository { + fun fetchModels(): Completable + fun fetchAndGetModels(): Single> + fun getModels(): Single> +} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/ImageToImageUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/ImageToImageUseCaseImpl.kt index 2336f99e..fcade672 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/ImageToImageUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/ImageToImageUseCaseImpl.kt @@ -7,11 +7,13 @@ import com.shifthackz.aisdv1.domain.repository.HordeGenerationRepository import com.shifthackz.aisdv1.domain.repository.HuggingFaceGenerationRepository import com.shifthackz.aisdv1.domain.repository.StabilityAiGenerationRepository import com.shifthackz.aisdv1.domain.repository.StableDiffusionGenerationRepository +import com.shifthackz.aisdv1.domain.repository.SwarmUiGenerationRepository import io.reactivex.rxjava3.core.Observable import io.reactivex.rxjava3.core.Single internal class ImageToImageUseCaseImpl( private val stableDiffusionGenerationRepository: StableDiffusionGenerationRepository, + private val swarmUiGenerationRepository: SwarmUiGenerationRepository, private val hordeGenerationRepository: HordeGenerationRepository, private val huggingFaceGenerationRepository: HuggingFaceGenerationRepository, private val stabilityAiGenerationRepository: StabilityAiGenerationRepository, @@ -25,6 +27,7 @@ internal class ImageToImageUseCaseImpl( private fun generate(payload: ImageToImagePayload) = when (preferenceManager.source) { ServerSource.AUTOMATIC1111 -> stableDiffusionGenerationRepository.generateFromImage(payload) + ServerSource.SWARM_UI -> swarmUiGenerationRepository.generateFromImage(payload) ServerSource.HORDE -> hordeGenerationRepository.generateFromImage(payload) ServerSource.HUGGING_FACE -> huggingFaceGenerationRepository.generateFromImage(payload) ServerSource.STABILITY_AI -> stabilityAiGenerationRepository.generateFromImage(payload) diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/stabilityai/FetchAndGetStabilityAiEnginesUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/stabilityai/FetchAndGetStabilityAiEnginesUseCaseImpl.kt index 93670eaa..5a17f6f6 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/stabilityai/FetchAndGetStabilityAiEnginesUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/stabilityai/FetchAndGetStabilityAiEnginesUseCaseImpl.kt @@ -1,5 +1,6 @@ package com.shifthackz.aisdv1.domain.usecase.stabilityai +import com.shifthackz.aisdv1.domain.entity.StabilityAiEngine import com.shifthackz.aisdv1.domain.preference.PreferenceManager import com.shifthackz.aisdv1.domain.repository.StabilityAiEnginesRepository import io.reactivex.rxjava3.core.Single @@ -12,7 +13,7 @@ internal class FetchAndGetStabilityAiEnginesUseCaseImpl( override fun invoke() = repository .fetchAndGet() .flatMap { engines -> - if (!engines.map { it.id }.contains(preferenceManager.stabilityAiEngineId)) { + if (!engines.map(StabilityAiEngine::id).contains(preferenceManager.stabilityAiEngineId)) { preferenceManager.stabilityAiEngineId = engines.firstOrNull()?.id ?: "" } Single.just(engines) diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/swarmmodel/FetchAndGetSwarmUiModelsUseCase.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/swarmmodel/FetchAndGetSwarmUiModelsUseCase.kt new file mode 100644 index 00000000..c672c93a --- /dev/null +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/swarmmodel/FetchAndGetSwarmUiModelsUseCase.kt @@ -0,0 +1,8 @@ +package com.shifthackz.aisdv1.domain.usecase.swarmmodel + +import com.shifthackz.aisdv1.domain.entity.SwarmUiModel +import io.reactivex.rxjava3.core.Single + +interface FetchAndGetSwarmUiModelsUseCase { + operator fun invoke(): Single>> +} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/swarmmodel/FetchAndGetSwarmUiModelsUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/swarmmodel/FetchAndGetSwarmUiModelsUseCaseImpl.kt new file mode 100644 index 00000000..d32191d6 --- /dev/null +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/swarmmodel/FetchAndGetSwarmUiModelsUseCaseImpl.kt @@ -0,0 +1,23 @@ +package com.shifthackz.aisdv1.domain.usecase.swarmmodel + +import com.shifthackz.aisdv1.domain.entity.SwarmUiModel +import com.shifthackz.aisdv1.domain.preference.PreferenceManager +import com.shifthackz.aisdv1.domain.repository.SwarmUiModelsRepository +import io.reactivex.rxjava3.core.Single + +internal class FetchAndGetSwarmUiModelsUseCaseImpl( + private val preferenceManager: PreferenceManager, + private val repository: SwarmUiModelsRepository, +) : FetchAndGetSwarmUiModelsUseCase { + + override fun invoke(): Single>> = repository + .fetchAndGetModels() + .map { models -> + if (!models.map(SwarmUiModel::name).contains(preferenceManager.swarmModel)) { + preferenceManager.swarmModel = models.firstOrNull()?.name ?: "" + } + models.map { model -> + model to (preferenceManager.swarmModel == model.name) + } + } +} diff --git a/network/src/main/java/com/shifthackz/aisdv1/network/api/swarmui/SwarmUiApi.kt b/network/src/main/java/com/shifthackz/aisdv1/network/api/swarmui/SwarmUiApi.kt index 15e55639..0f261edd 100644 --- a/network/src/main/java/com/shifthackz/aisdv1/network/api/swarmui/SwarmUiApi.kt +++ b/network/src/main/java/com/shifthackz/aisdv1/network/api/swarmui/SwarmUiApi.kt @@ -2,7 +2,9 @@ package com.shifthackz.aisdv1.network.api.swarmui import android.graphics.Bitmap import com.shifthackz.aisdv1.network.request.SwarmUiGenerationRequest +import com.shifthackz.aisdv1.network.request.SwarmUiModelsRequest import com.shifthackz.aisdv1.network.response.SwarmUiGenerationResponse +import com.shifthackz.aisdv1.network.response.SwarmUiModelsResponse import com.shifthackz.aisdv1.network.response.SwarmUiSessionResponse import io.reactivex.rxjava3.core.Single import okhttp3.ResponseBody @@ -17,24 +19,38 @@ interface SwarmUiApi { fun getNewSession(url: String): Single - fun textToImage( + fun generate( @Url url: String, @Body request: SwarmUiGenerationRequest, ): Single + fun fetchModels( + @Url url: String, + @Body request: SwarmUiModelsRequest, + ): Single + fun downloadImage(url: String): Single interface RawApi { @POST - fun getNewSession(@Url url: String, @Body map: Map): Single + fun getNewSession( + @Url url: String, + @Body map: Map, + ): Single @POST - fun textToImage( + fun generate( @Url url: String, @Body request: SwarmUiGenerationRequest, ): Single + @POST + fun fetchModels( + @Url url: String, + @Body request: SwarmUiModelsRequest, + ): Single + @Streaming @GET fun download(@Url url: String): Single> @@ -42,6 +58,7 @@ interface SwarmUiApi { companion object { const val PATH_SESSION = "API/GetNewSession" - const val PATH_TXT_TO_IMG = "API/GenerateText2Image" + const val PATH_GENERATE = "API/GenerateText2Image" + const val PATH_MODELS = "API/ListModels" } } diff --git a/network/src/main/java/com/shifthackz/aisdv1/network/api/swarmui/SwarmUiApiImpl.kt b/network/src/main/java/com/shifthackz/aisdv1/network/api/swarmui/SwarmUiApiImpl.kt index 56a6dc4b..e17a46a7 100644 --- a/network/src/main/java/com/shifthackz/aisdv1/network/api/swarmui/SwarmUiApiImpl.kt +++ b/network/src/main/java/com/shifthackz/aisdv1/network/api/swarmui/SwarmUiApiImpl.kt @@ -1,21 +1,33 @@ package com.shifthackz.aisdv1.network.api.swarmui +import android.graphics.Bitmap import android.graphics.BitmapFactory import com.shifthackz.aisdv1.network.request.SwarmUiGenerationRequest +import com.shifthackz.aisdv1.network.request.SwarmUiModelsRequest +import com.shifthackz.aisdv1.network.response.SwarmUiGenerationResponse +import com.shifthackz.aisdv1.network.response.SwarmUiModelsResponse +import com.shifthackz.aisdv1.network.response.SwarmUiSessionResponse import io.reactivex.rxjava3.core.Single internal class SwarmUiApiImpl( private val rawApi: SwarmUiApi.RawApi, ) : SwarmUiApi { - override fun getNewSession(url: String) = rawApi.getNewSession(url, emptyMap()) + override fun getNewSession(url: String): Single = rawApi + .getNewSession(url, emptyMap()) - override fun textToImage( + override fun generate( url: String, request: SwarmUiGenerationRequest, - ) = rawApi.textToImage(url, request) + ): Single = rawApi.generate(url, request) - override fun downloadImage(url: String) = rawApi.download(url) + override fun fetchModels( + url: String, + request: SwarmUiModelsRequest + ): Single = rawApi.fetchModels(url, request) + + override fun downloadImage(url: String): Single = rawApi + .download(url) .flatMap { response -> response.body() ?.bytes() diff --git a/network/src/main/java/com/shifthackz/aisdv1/network/model/SwarmUiModelRaw.kt b/network/src/main/java/com/shifthackz/aisdv1/network/model/SwarmUiModelRaw.kt new file mode 100644 index 00000000..47f5fc53 --- /dev/null +++ b/network/src/main/java/com/shifthackz/aisdv1/network/model/SwarmUiModelRaw.kt @@ -0,0 +1,12 @@ +package com.shifthackz.aisdv1.network.model + +import com.google.gson.annotations.SerializedName + +data class SwarmUiModelRaw( + @SerializedName("name") + val name: String?, + @SerializedName("title") + val title: String?, + @SerializedName("author") + val author: String?, +) diff --git a/network/src/main/java/com/shifthackz/aisdv1/network/request/SwarmUiGenerationRequest.kt b/network/src/main/java/com/shifthackz/aisdv1/network/request/SwarmUiGenerationRequest.kt index b55df94a..80a75bec 100644 --- a/network/src/main/java/com/shifthackz/aisdv1/network/request/SwarmUiGenerationRequest.kt +++ b/network/src/main/java/com/shifthackz/aisdv1/network/request/SwarmUiGenerationRequest.kt @@ -7,12 +7,33 @@ data class SwarmUiGenerationRequest( val sessionId: String, @SerializedName("model") val model: String, + @SerializedName("initimage") + val initImage: String?, @SerializedName("images") val images: Int, @SerializedName("prompt") val prompt: String, + @SerializedName("negativeprompt") + val negativePrompt: String, @SerializedName("width") val width: Int, @SerializedName("height") val height: Int, + @SerializedName("seed") + val seed: String?, + @SerializedName("variationseed") + val variationSeed: String?, + @SerializedName("variationseedstrength") + val variationSeedStrength: String?, + @SerializedName("cfgscale") + val cfgScale: Float?, + @SerializedName("steps") + val steps: Int, + @SerializedName("initimagecreativity") + val initimagecreativity: String = "0.6", + @SerializedName("initimageresettonorm") + val initimageresettonorm: String = "0", + @SerializedName("initimagerecompositemask") + val initimagerecompositemask: String = "0", + ) diff --git a/network/src/main/java/com/shifthackz/aisdv1/network/request/SwarmUiModelsRequest.kt b/network/src/main/java/com/shifthackz/aisdv1/network/request/SwarmUiModelsRequest.kt new file mode 100644 index 00000000..83a98aa3 --- /dev/null +++ b/network/src/main/java/com/shifthackz/aisdv1/network/request/SwarmUiModelsRequest.kt @@ -0,0 +1,12 @@ +package com.shifthackz.aisdv1.network.request + +import com.google.gson.annotations.SerializedName + +data class SwarmUiModelsRequest( + @SerializedName("session_id") + val sessionId: String, + @SerializedName("path") + val path: String, + @SerializedName("depth") + val depth: Int, +) diff --git a/network/src/main/java/com/shifthackz/aisdv1/network/response/SwarmUiModelsResponse.kt b/network/src/main/java/com/shifthackz/aisdv1/network/response/SwarmUiModelsResponse.kt new file mode 100644 index 00000000..4913c5b9 --- /dev/null +++ b/network/src/main/java/com/shifthackz/aisdv1/network/response/SwarmUiModelsResponse.kt @@ -0,0 +1,9 @@ +package com.shifthackz.aisdv1.network.response + +import com.google.gson.annotations.SerializedName +import com.shifthackz.aisdv1.network.model.SwarmUiModelRaw + +data class SwarmUiModelsResponse( + @SerializedName("files") + val files: List?, +) diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageScreen.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageScreen.kt index 08065943..77e61f25 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageScreen.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageScreen.kt @@ -136,6 +136,7 @@ private fun ScreenContent( content = { paddingValues -> when (state.mode) { ServerSource.AUTOMATIC1111, + ServerSource.SWARM_UI, ServerSource.HORDE, ServerSource.STABILITY_AI, ServerSource.HUGGING_FACE -> { diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsState.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsState.kt index 306ca562..1239c701 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsState.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsState.kt @@ -39,7 +39,7 @@ data class SettingsState( get() = serverSource == ServerSource.AUTOMATIC1111 val showMonitorConnectionOption: Boolean - get() = serverSource == ServerSource.AUTOMATIC1111 + get() = serverSource == ServerSource.AUTOMATIC1111 || serverSource == ServerSource.SWARM_UI val showFormAdvancedOption: Boolean get() = serverSource != ServerSource.OPEN_AI diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionComponent.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionComponent.kt index c6ec9334..878640b9 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionComponent.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionComponent.kt @@ -27,6 +27,15 @@ fun EngineSelectionComponent( onItemSelected = { intentHandler(EngineSelectionIntent(it)) }, ) + ServerSource.SWARM_UI -> DropdownTextField( + label = R.string.hint_sd_model.asUiText(), + loading = state.loading, + modifier = modifier, + value = state.selectedSwarmModel, + items = state.swarmModels, + onItemSelected = { intentHandler(EngineSelectionIntent(it)) }, + ) + ServerSource.HUGGING_FACE -> DropdownTextField( label = R.string.hint_hugging_face_model.asUiText(), loading = state.loading, @@ -57,7 +66,6 @@ fun EngineSelectionComponent( ServerSource.HORDE -> Unit ServerSource.OPEN_AI -> Unit - ServerSource.SWARM_UI -> Unit } } } diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionState.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionState.kt index 7b510934..382a299d 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionState.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionState.kt @@ -11,6 +11,8 @@ data class EngineSelectionState( val mode: ServerSource = ServerSource.AUTOMATIC1111, val sdModels: List = emptyList(), val selectedSdModel: String = "", + val swarmModels: List = emptyList(), + val selectedSwarmModel: String = "", val hfModels: List = emptyList(), val selectedHfModel: String = "", val stEngines: List = emptyList(), diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModel.kt index 33582b3b..a541754f 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModel.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModel.kt @@ -2,7 +2,7 @@ package com.shifthackz.aisdv1.presentation.widget.engine import com.shifthackz.aisdv1.core.common.extensions.EmptyLambda import com.shifthackz.aisdv1.core.common.log.errorLog -import com.shifthackz.aisdv1.core.common.model.Quintuple +import com.shifthackz.aisdv1.core.common.model.Hexagonal import com.shifthackz.aisdv1.core.common.schedulers.SchedulersProvider import com.shifthackz.aisdv1.core.common.schedulers.subscribeOnMainThread import com.shifthackz.aisdv1.core.viewmodel.MviRxViewModel @@ -16,6 +16,7 @@ import com.shifthackz.aisdv1.domain.usecase.sdmodel.GetStableDiffusionModelsUseC import com.shifthackz.aisdv1.domain.usecase.sdmodel.SelectStableDiffusionModelUseCase import com.shifthackz.aisdv1.domain.usecase.settings.GetConfigurationUseCase import com.shifthackz.aisdv1.domain.usecase.stabilityai.FetchAndGetStabilityAiEnginesUseCase +import com.shifthackz.aisdv1.domain.usecase.swarmmodel.FetchAndGetSwarmUiModelsUseCase import com.shifthackz.android.core.mvi.EmptyEffect import io.reactivex.rxjava3.core.Flowable import io.reactivex.rxjava3.kotlin.subscribeBy @@ -26,6 +27,7 @@ class EngineSelectionViewModel( private val getConfigurationUseCase: GetConfigurationUseCase, private val selectStableDiffusionModelUseCase: SelectStableDiffusionModelUseCase, private val getStableDiffusionModelsUseCase: GetStableDiffusionModelsUseCase, + private val fetchAndGetSwarmUiModelsUseCase: FetchAndGetSwarmUiModelsUseCase, observeLocalAiModelsUseCase: ObserveLocalAiModelsUseCase, fetchAndGetStabilityAiEnginesUseCase: FetchAndGetStabilityAiEnginesUseCase, getHuggingFaceModelsUseCase: FetchAndGetHuggingFaceModelsUseCase, @@ -43,6 +45,10 @@ class EngineSelectionViewModel( .onErrorReturn { emptyList() } .toFlowable() + val swarmModels = fetchAndGetSwarmUiModelsUseCase() + .onErrorReturn { emptyList() } + .toFlowable() + val huggingFaceModels = getHuggingFaceModelsUseCase() .onErrorReturn { emptyList() } .toFlowable() @@ -58,16 +64,17 @@ class EngineSelectionViewModel( !Flowable.combineLatest( configuration, a1111Models, + swarmModels, huggingFaceModels, stabilityAiEngines, localAiModels, - ::Quintuple, + ::Hexagonal, ) .subscribeOnMainThread(schedulersProvider) .subscribeBy( onError = ::errorLog, onComplete = EmptyLambda, - onNext = { (config, sdModels, hfModels, stEngines, localModels) -> + onNext = { (config, sdModels, swarmModels, hfModels, stEngines, localModels) -> updateState { state -> state.copy( loading = false, @@ -75,6 +82,9 @@ class EngineSelectionViewModel( sdModels = sdModels.map { it.first.title }, selectedSdModel = sdModels.firstOrNull { it.second }?.first?.title ?: state.selectedSdModel, + swarmModels = swarmModels.map { it.first.name }, + selectedSwarmModel = swarmModels.firstOrNull { it.second }?.first?.name + ?: state.selectedSwarmModel, hfModels = hfModels.map { it.alias }, selectedHfModel = config.huggingFaceModel, stEngines = stEngines.map { it.id }, @@ -111,6 +121,8 @@ class EngineSelectionViewModel( } } + ServerSource.SWARM_UI -> preferenceManager.swarmModel = intent.value + ServerSource.HUGGING_FACE -> preferenceManager.huggingFaceModel = intent.value ServerSource.STABILITY_AI -> preferenceManager.stabilityAiEngineId = intent.value diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputForm.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputForm.kt index 6fea88c5..f82e6ce3 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputForm.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputForm.kt @@ -141,6 +141,7 @@ fun GenerationInputForm( Column(modifier = modifier) { when (state.mode) { ServerSource.AUTOMATIC1111, + ServerSource.SWARM_UI, ServerSource.STABILITY_AI, ServerSource.HUGGING_FACE, ServerSource.LOCAL -> EngineSelectionComponent( @@ -195,6 +196,7 @@ fun GenerationInputForm( // Horde does not support "negative prompt" when (state.mode) { ServerSource.AUTOMATIC1111, + ServerSource.SWARM_UI, ServerSource.HUGGING_FACE, ServerSource.STABILITY_AI, ServerSource.LOCAL -> { @@ -265,6 +267,7 @@ fun GenerationInputForm( } ServerSource.AUTOMATIC1111, + ServerSource.SWARM_UI, ServerSource.HUGGING_FACE -> { sizeTextFieldsComponent(localModifier) } @@ -285,8 +288,6 @@ fun GenerationInputForm( displayDelegate = { it.key.asUiText() }, ) } - - ServerSource.SWARM_UI -> Unit } } @@ -433,9 +434,10 @@ fun GenerationInputForm( ) } } - // Variation seed only supported for A1111 - if (state.mode == ServerSource.AUTOMATIC1111) { - TextField( + // Variation seed supported for A1111, SwarmUI + when (state.mode) { + ServerSource.AUTOMATIC1111, + ServerSource.SWARM_UI -> TextField( modifier = Modifier .fillMaxWidth() .padding(top = 8.dp), @@ -459,10 +461,13 @@ fun GenerationInputForm( } }, ) + + else -> Unit } // Sub-seed strength is not available for Local Diffusion when (state.mode) { ServerSource.AUTOMATIC1111, + ServerSource.SWARM_UI, ServerSource.HORDE -> { Text( modifier = Modifier.padding(top = 8.dp), diff --git a/storage/schemas/com.shifthackz.aisdv1.storage.db.cache.CacheDatabase/1.json b/storage/schemas/com.shifthackz.aisdv1.storage.db.cache.CacheDatabase/1.json index 59134bbd..2f273525 100644 --- a/storage/schemas/com.shifthackz.aisdv1.storage.db.cache.CacheDatabase/1.json +++ b/storage/schemas/com.shifthackz.aisdv1.storage.db.cache.CacheDatabase/1.json @@ -2,7 +2,7 @@ "formatVersion": 1, "database": { "version": 1, - "identityHash": "49878af78a080f644679bc65b796475a", + "identityHash": "5dc36f6910330fd3f65e65c4e15e56c2", "entities": [ { "tableName": "server_config", @@ -219,12 +219,50 @@ }, "indices": [], "foreignKeys": [] + }, + { + "tableName": "swarm_models", + "createSql": "CREATE TABLE IF NOT EXISTS `${TABLE_NAME}` (`id` TEXT NOT NULL, `name` TEXT NOT NULL, `title` TEXT NOT NULL, `author` TEXT NOT NULL, PRIMARY KEY(`id`))", + "fields": [ + { + "fieldPath": "id", + "columnName": "id", + "affinity": "TEXT", + "notNull": true + }, + { + "fieldPath": "name", + "columnName": "name", + "affinity": "TEXT", + "notNull": true + }, + { + "fieldPath": "title", + "columnName": "title", + "affinity": "TEXT", + "notNull": true + }, + { + "fieldPath": "author", + "columnName": "author", + "affinity": "TEXT", + "notNull": true + } + ], + "primaryKey": { + "autoGenerate": false, + "columnNames": [ + "id" + ] + }, + "indices": [], + "foreignKeys": [] } ], "views": [], "setupQueries": [ "CREATE TABLE IF NOT EXISTS room_master_table (id INTEGER PRIMARY KEY,identity_hash TEXT)", - "INSERT OR REPLACE INTO room_master_table (id,identity_hash) VALUES(42, '49878af78a080f644679bc65b796475a')" + "INSERT OR REPLACE INTO room_master_table (id,identity_hash) VALUES(42, '5dc36f6910330fd3f65e65c4e15e56c2')" ] } } \ No newline at end of file diff --git a/storage/src/main/java/com/shifthackz/aisdv1/storage/db/cache/CacheDatabase.kt b/storage/src/main/java/com/shifthackz/aisdv1/storage/db/cache/CacheDatabase.kt index 7db368fa..3dc34f74 100755 --- a/storage/src/main/java/com/shifthackz/aisdv1/storage/db/cache/CacheDatabase.kt +++ b/storage/src/main/java/com/shifthackz/aisdv1/storage/db/cache/CacheDatabase.kt @@ -12,12 +12,14 @@ import com.shifthackz.aisdv1.storage.db.cache.dao.StableDiffusionHyperNetworkDao import com.shifthackz.aisdv1.storage.db.cache.dao.StableDiffusionLoraDao import com.shifthackz.aisdv1.storage.db.cache.dao.StableDiffusionModelDao import com.shifthackz.aisdv1.storage.db.cache.dao.StableDiffusionSamplerDao +import com.shifthackz.aisdv1.storage.db.cache.dao.SwarmUiModelDao import com.shifthackz.aisdv1.storage.db.cache.entity.ServerConfigurationEntity import com.shifthackz.aisdv1.storage.db.cache.entity.StableDiffusionEmbeddingEntity import com.shifthackz.aisdv1.storage.db.cache.entity.StableDiffusionHyperNetworkEntity import com.shifthackz.aisdv1.storage.db.cache.entity.StableDiffusionLoraEntity import com.shifthackz.aisdv1.storage.db.cache.entity.StableDiffusionModelEntity import com.shifthackz.aisdv1.storage.db.cache.entity.StableDiffusionSamplerEntity +import com.shifthackz.aisdv1.storage.db.cache.entity.SwarmUiModelEntity @Database( version = DB_VERSION, @@ -29,6 +31,7 @@ import com.shifthackz.aisdv1.storage.db.cache.entity.StableDiffusionSamplerEntit StableDiffusionLoraEntity::class, StableDiffusionHyperNetworkEntity::class, StableDiffusionEmbeddingEntity::class, + SwarmUiModelEntity::class, ], ) @TypeConverters( @@ -42,6 +45,7 @@ internal abstract class CacheDatabase : RoomDatabase() { abstract fun sdLoraDao(): StableDiffusionLoraDao abstract fun sdHyperNetworkDao(): StableDiffusionHyperNetworkDao abstract fun sdEmbeddingDao(): StableDiffusionEmbeddingDao + abstract fun swarmUiModelDao(): SwarmUiModelDao companion object { const val DB_VERSION = 1 diff --git a/storage/src/main/java/com/shifthackz/aisdv1/storage/db/cache/contract/SwarmUiModelContract.kt b/storage/src/main/java/com/shifthackz/aisdv1/storage/db/cache/contract/SwarmUiModelContract.kt new file mode 100644 index 00000000..2a900c59 --- /dev/null +++ b/storage/src/main/java/com/shifthackz/aisdv1/storage/db/cache/contract/SwarmUiModelContract.kt @@ -0,0 +1,10 @@ +package com.shifthackz.aisdv1.storage.db.cache.contract + +internal object SwarmUiModelContract { + const val TABLE = "swarm_models" + + const val ID = "id" + const val NAME = "name" + const val TITLE = "title" + const val AUTHOR = "author" +} diff --git a/storage/src/main/java/com/shifthackz/aisdv1/storage/db/cache/dao/SwarmUiModelDao.kt b/storage/src/main/java/com/shifthackz/aisdv1/storage/db/cache/dao/SwarmUiModelDao.kt new file mode 100644 index 00000000..bda49f52 --- /dev/null +++ b/storage/src/main/java/com/shifthackz/aisdv1/storage/db/cache/dao/SwarmUiModelDao.kt @@ -0,0 +1,23 @@ +package com.shifthackz.aisdv1.storage.db.cache.dao + +import androidx.room.Dao +import androidx.room.Insert +import androidx.room.OnConflictStrategy +import androidx.room.Query +import com.shifthackz.aisdv1.storage.db.cache.contract.SwarmUiModelContract +import com.shifthackz.aisdv1.storage.db.cache.entity.SwarmUiModelEntity +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single + +@Dao +interface SwarmUiModelDao { + + @Query("SELECT * FROM ${SwarmUiModelContract.TABLE}") + fun queryAll(): Single> + + @Insert(onConflict = OnConflictStrategy.REPLACE) + fun insertList(items: List): Completable + + @Query("DELETE FROM ${SwarmUiModelContract.TABLE}") + fun deleteAll(): Completable +} diff --git a/storage/src/main/java/com/shifthackz/aisdv1/storage/db/cache/entity/SwarmUiModelEntity.kt b/storage/src/main/java/com/shifthackz/aisdv1/storage/db/cache/entity/SwarmUiModelEntity.kt new file mode 100644 index 00000000..f79f12b0 --- /dev/null +++ b/storage/src/main/java/com/shifthackz/aisdv1/storage/db/cache/entity/SwarmUiModelEntity.kt @@ -0,0 +1,19 @@ +package com.shifthackz.aisdv1.storage.db.cache.entity + +import androidx.room.ColumnInfo +import androidx.room.Entity +import androidx.room.PrimaryKey +import com.shifthackz.aisdv1.storage.db.cache.contract.SwarmUiModelContract + +@Entity(tableName = SwarmUiModelContract.TABLE) +data class SwarmUiModelEntity( + @PrimaryKey(autoGenerate = false) + @ColumnInfo(name = SwarmUiModelContract.ID) + val id: String, + @ColumnInfo(name = SwarmUiModelContract.NAME) + val name: String, + @ColumnInfo(name = SwarmUiModelContract.TITLE) + val title: String, + @ColumnInfo(name = SwarmUiModelContract.AUTHOR) + val author: String, +) diff --git a/storage/src/main/java/com/shifthackz/aisdv1/storage/di/DatabaseModule.kt b/storage/src/main/java/com/shifthackz/aisdv1/storage/di/DatabaseModule.kt index f84172db..68255373 100755 --- a/storage/src/main/java/com/shifthackz/aisdv1/storage/di/DatabaseModule.kt +++ b/storage/src/main/java/com/shifthackz/aisdv1/storage/di/DatabaseModule.kt @@ -43,6 +43,7 @@ val databaseModule = module { single { get().sdHyperNetworkDao() } single { get().sdEmbeddingDao() } single { get().serverConfigurationDao() } + single { get().swarmUiModelDao() } //endregion //region PERSISTENT DB DAOs From 21f3637932b08323d010df2ea8a2cf80829f0be2 Mon Sep 17 00:00:00 2001 From: ShiftHackZ Date: Sun, 4 Aug 2024 22:39:08 +0300 Subject: [PATCH 03/11] Swarm UI LoRAs implementation --- .../imageprocessing/utils/Base64ImageUtils.kt | 5 ++ .../aisdv1/data/di/LocalDataSourceModule.kt | 6 +-- .../aisdv1/data/di/RemoteDataSourceModule.kt | 10 ++-- .../aisdv1/data/di/RepositoryModule.kt | 6 +-- ...lDataSource.kt => LorasLocalDataSource.kt} | 10 ++-- .../data/mappers/HuggingFaceModelMappers.kt | 6 +-- .../mappers/ImageToImagePayloadMappers.kt | 29 +++++++++++ .../data/mappers/LocalAiModelMappers.kt | 6 +-- .../data/mappers/StabilityAiEngineMappers.kt | 6 +-- .../StableDiffusionEmbeddingsMappers.kt | 2 +- .../StableDiffusionHyperNetworksMappers.kt | 6 +-- .../mappers/StableDiffusionLorasMappers.kt | 20 ++++---- .../mappers/StableDiffusionModelsMappers.kt | 6 +-- .../mappers/StableDiffusionSamplersMappers.kt | 6 +-- .../SwarmUiImageToImagePayloadMappers.kt | 27 ----------- .../data/mappers/SwarmUiModelsMappers.kt | 31 +++++++++--- .../SwarmUiTextToImagePayloadMappers.kt | 48 ------------------- .../data/mappers/TextToImagePayloadMappers.kt | 27 +++++++++++ .../data/preference/PreferenceManagerImpl.kt | 8 ++-- .../DownloadableModelRemoteDataSource.kt | 4 +- .../HuggingFaceModelsRemoteDataSource.kt | 4 +- .../StabilityAiEnginesRemoteDataSource.kt | 4 +- ...ableDiffusionEmbeddingsRemoteDataSource.kt | 4 +- ...eDiffusionHyperNetworksRemoteDataSource.kt | 4 +- .../StableDiffusionLorasRemoteDataSource.kt | 8 ++-- .../StableDiffusionModelsRemoteDataSource.kt | 4 +- ...StableDiffusionSamplersRemoteDataSource.kt | 4 +- .../remote/SwarmUiLorasRemoteDataSource.kt | 29 +++++++++++ .../remote/SwarmUiModelsRemoteDataSource.kt | 5 +- .../data/repository/LorasRepositoryImpl.kt | 38 +++++++++++++++ .../StableDiffusionLorasRepositoryImpl.kt | 20 -------- .../SwarmUiGenerationRepositoryImpl.kt | 4 +- ...StableDiffusionLorasLocalDataSourceTest.kt | 2 +- .../data/mocks/StableDiffusionLoraMocks.kt | 6 +-- .../preference/PreferenceManagerImplTest.kt | 6 +-- .../DownloadableModelRemoteDataSourceTest.kt | 4 +- ...ableDiffusionLorasRemoteDataSourceTest.kt} | 8 ++-- ...ImplTest.kt => LorasRepositoryImplTest.kt} | 15 +++--- .../domain/datasource/LorasDataSource.kt | 24 ++++++++++ .../StableDiffusionLorasDataSource.kt | 18 ------- .../aisdv1/domain/entity/Configuration.kt | 1 + .../{StableDiffusionLora.kt => LoRA.kt} | 2 +- .../domain/preference/PreferenceManager.kt | 6 +-- .../domain/repository/LorasRepository.kt | 11 +++++ .../StableDiffusionLorasRepository.kt | 11 ----- .../caching/DataPreLoaderUseCaseImpl.kt | 5 +- .../usecase/sdlora/FetchAndGetLorasUseCase.kt | 4 +- .../sdlora/FetchAndGetLorasUseCaseImpl.kt | 6 +-- .../settings/GetConfigurationUseCaseImpl.kt | 5 +- .../SetServerConfigurationUseCaseImpl.kt | 6 +-- .../splash/SplashNavigationUseCaseImpl.kt | 2 +- .../FetchAndGetSwarmUiModelsUseCase.kt | 2 +- .../FetchAndGetSwarmUiModelsUseCaseImpl.kt | 10 ++-- .../caching/DataPreLoaderUseCaseImplTest.kt | 20 ++++---- .../GetConfigurationUseCaseImplTest.kt | 2 +- .../SetServerConfigurationUseCaseImplTest.kt | 2 +- .../splash/SplashNavigationUseCaseImplTest.kt | 8 ++-- .../request/SwarmUiGenerationRequest.kt | 12 ++--- .../network/request/SwarmUiModelsRequest.kt | 2 + .../presentation/modal/extras/ExtrasScreen.kt | 1 + .../modal/extras/ExtrasViewModel.kt | 6 +-- .../widget/engine/EngineSelectionViewModel.kt | 8 ++-- .../widget/input/GenerationInputForm.kt | 3 +- .../widget/toolbar/GenearionBottomToolbar.kt | 42 ++++++++-------- .../mocks/StableDiffusionLoraMocks.kt | 6 +-- 65 files changed, 364 insertions(+), 299 deletions(-) rename data/src/main/java/com/shifthackz/aisdv1/data/local/{StableDiffusionLorasLocalDataSource.kt => LorasLocalDataSource.kt} (64%) delete mode 100644 data/src/main/java/com/shifthackz/aisdv1/data/mappers/SwarmUiImageToImagePayloadMappers.kt delete mode 100644 data/src/main/java/com/shifthackz/aisdv1/data/mappers/SwarmUiTextToImagePayloadMappers.kt create mode 100644 data/src/main/java/com/shifthackz/aisdv1/data/remote/SwarmUiLorasRemoteDataSource.kt create mode 100644 data/src/main/java/com/shifthackz/aisdv1/data/repository/LorasRepositoryImpl.kt delete mode 100644 data/src/main/java/com/shifthackz/aisdv1/data/repository/StableDiffusionLorasRepositoryImpl.kt rename data/src/test/java/com/shifthackz/aisdv1/data/remote/{StableDiffusionLorasRemoteDataSourceTest.kt => StableDiffusionStableDiffusionLorasRemoteDataSourceTest.kt} (92%) rename data/src/test/java/com/shifthackz/aisdv1/data/repository/{StableDiffusionLorasRepositoryImplTest.kt => LorasRepositoryImplTest.kt} (90%) create mode 100644 domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/LorasDataSource.kt delete mode 100644 domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/StableDiffusionLorasDataSource.kt rename domain/src/main/java/com/shifthackz/aisdv1/domain/entity/{StableDiffusionLora.kt => LoRA.kt} (78%) create mode 100644 domain/src/main/java/com/shifthackz/aisdv1/domain/repository/LorasRepository.kt delete mode 100644 domain/src/main/java/com/shifthackz/aisdv1/domain/repository/StableDiffusionLorasRepository.kt diff --git a/core/imageprocessing/src/main/java/com/shifthackz/aisdv1/core/imageprocessing/utils/Base64ImageUtils.kt b/core/imageprocessing/src/main/java/com/shifthackz/aisdv1/core/imageprocessing/utils/Base64ImageUtils.kt index 5865d72d..85557818 100644 --- a/core/imageprocessing/src/main/java/com/shifthackz/aisdv1/core/imageprocessing/utils/Base64ImageUtils.kt +++ b/core/imageprocessing/src/main/java/com/shifthackz/aisdv1/core/imageprocessing/utils/Base64ImageUtils.kt @@ -15,3 +15,8 @@ fun bitmapToBase64(bitmap: Bitmap): String { bitmap.compress(Bitmap.CompressFormat.PNG, 100, outputStream) return Base64.encodeToString(outputStream.toByteArray(), Base64.DEFAULT) } + +fun base64DefaultToNoWrap(base64Default: String): String { + val byteArray = Base64.decode(base64Default, Base64.DEFAULT) + return Base64.encodeToString(byteArray, Base64.NO_WRAP) +} diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/di/LocalDataSourceModule.kt b/data/src/main/java/com/shifthackz/aisdv1/data/di/LocalDataSourceModule.kt index 4d4a8012..4373902e 100755 --- a/data/src/main/java/com/shifthackz/aisdv1/data/di/LocalDataSourceModule.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/di/LocalDataSourceModule.kt @@ -5,22 +5,22 @@ import com.shifthackz.aisdv1.data.gateway.mediastore.MediaStoreGatewayFactory import com.shifthackz.aisdv1.data.local.DownloadableModelLocalDataSource import com.shifthackz.aisdv1.data.local.GenerationResultLocalDataSource import com.shifthackz.aisdv1.data.local.HuggingFaceModelsLocalDataSource +import com.shifthackz.aisdv1.data.local.LorasLocalDataSource import com.shifthackz.aisdv1.data.local.ServerConfigurationLocalDataSource import com.shifthackz.aisdv1.data.local.StabilityAiCreditsLocalDataSource import com.shifthackz.aisdv1.data.local.StableDiffusionEmbeddingsLocalDataSource import com.shifthackz.aisdv1.data.local.StableDiffusionHyperNetworksLocalDataSource -import com.shifthackz.aisdv1.data.local.StableDiffusionLorasLocalDataSource import com.shifthackz.aisdv1.data.local.StableDiffusionModelsLocalDataSource import com.shifthackz.aisdv1.data.local.StableDiffusionSamplersLocalDataSource import com.shifthackz.aisdv1.data.local.SwarmUiModelsLocalDataSource import com.shifthackz.aisdv1.domain.datasource.DownloadableModelDataSource import com.shifthackz.aisdv1.domain.datasource.GenerationResultDataSource import com.shifthackz.aisdv1.domain.datasource.HuggingFaceModelsDataSource +import com.shifthackz.aisdv1.domain.datasource.LorasDataSource import com.shifthackz.aisdv1.domain.datasource.ServerConfigurationDataSource import com.shifthackz.aisdv1.domain.datasource.StabilityAiCreditsDataSource import com.shifthackz.aisdv1.domain.datasource.StableDiffusionEmbeddingsDataSource import com.shifthackz.aisdv1.domain.datasource.StableDiffusionHyperNetworksDataSource -import com.shifthackz.aisdv1.domain.datasource.StableDiffusionLorasDataSource import com.shifthackz.aisdv1.domain.datasource.StableDiffusionModelsDataSource import com.shifthackz.aisdv1.domain.datasource.StableDiffusionSamplersDataSource import com.shifthackz.aisdv1.domain.datasource.SwarmUiModelsDataSource @@ -37,7 +37,7 @@ val localDataSourceModule = module { single { StabilityAiCreditsLocalDataSource() } factoryOf(::StableDiffusionModelsLocalDataSource) bind StableDiffusionModelsDataSource.Local::class factoryOf(::StableDiffusionSamplersLocalDataSource) bind StableDiffusionSamplersDataSource.Local::class - factoryOf(::StableDiffusionLorasLocalDataSource) bind StableDiffusionLorasDataSource.Local::class + factoryOf(::LorasLocalDataSource) bind LorasDataSource.Local::class factoryOf(::StableDiffusionHyperNetworksLocalDataSource) bind StableDiffusionHyperNetworksDataSource.Local::class factoryOf(::StableDiffusionEmbeddingsLocalDataSource) bind StableDiffusionEmbeddingsDataSource.Local::class factoryOf(::SwarmUiModelsLocalDataSource) bind SwarmUiModelsDataSource.Local::class diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/di/RemoteDataSourceModule.kt b/data/src/main/java/com/shifthackz/aisdv1/data/di/RemoteDataSourceModule.kt index 3a834ea8..25cdba92 100755 --- a/data/src/main/java/com/shifthackz/aisdv1/data/di/RemoteDataSourceModule.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/di/RemoteDataSourceModule.kt @@ -21,12 +21,14 @@ import com.shifthackz.aisdv1.data.remote.StableDiffusionLorasRemoteDataSource import com.shifthackz.aisdv1.data.remote.StableDiffusionModelsRemoteDataSource import com.shifthackz.aisdv1.data.remote.StableDiffusionSamplersRemoteDataSource import com.shifthackz.aisdv1.data.remote.SwarmUiGenerationRemoteDataSource +import com.shifthackz.aisdv1.data.remote.SwarmUiLorasRemoteDataSource import com.shifthackz.aisdv1.data.remote.SwarmUiModelsRemoteDataSource import com.shifthackz.aisdv1.data.remote.SwarmUiSessionDataSourceImpl import com.shifthackz.aisdv1.domain.datasource.DownloadableModelDataSource import com.shifthackz.aisdv1.domain.datasource.HordeGenerationDataSource import com.shifthackz.aisdv1.domain.datasource.HuggingFaceGenerationDataSource import com.shifthackz.aisdv1.domain.datasource.HuggingFaceModelsDataSource +import com.shifthackz.aisdv1.domain.datasource.LorasDataSource import com.shifthackz.aisdv1.domain.datasource.OpenAiGenerationDataSource import com.shifthackz.aisdv1.domain.datasource.RandomImageDataSource import com.shifthackz.aisdv1.domain.datasource.ServerConfigurationDataSource @@ -36,7 +38,6 @@ import com.shifthackz.aisdv1.domain.datasource.StabilityAiGenerationDataSource import com.shifthackz.aisdv1.domain.datasource.StableDiffusionEmbeddingsDataSource import com.shifthackz.aisdv1.domain.datasource.StableDiffusionGenerationDataSource import com.shifthackz.aisdv1.domain.datasource.StableDiffusionHyperNetworksDataSource -import com.shifthackz.aisdv1.domain.datasource.StableDiffusionLorasDataSource import com.shifthackz.aisdv1.domain.datasource.StableDiffusionModelsDataSource import com.shifthackz.aisdv1.domain.datasource.StableDiffusionSamplersDataSource import com.shifthackz.aisdv1.domain.datasource.SwarmUiGenerationDataSource @@ -58,9 +59,9 @@ val remoteDataSourceModule = module { ServerUrlProvider { endpoint -> val prefs = get() val chain = if (prefs.source == ServerSource.SWARM_UI) { - Single.fromCallable(prefs::swarmServerUrl) + Single.fromCallable(prefs::swarmUiServerUrl) } else { - Single.fromCallable(prefs::serverUrl) + Single.fromCallable(prefs::automatic1111serverUrl) } chain .map(String::fixUrlSlashes) @@ -74,10 +75,11 @@ val remoteDataSourceModule = module { factoryOf(::SwarmUiSessionDataSourceImpl) bind SwarmUiSessionDataSource::class factoryOf(::SwarmUiGenerationRemoteDataSource) bind SwarmUiGenerationDataSource.Remote::class factoryOf(::SwarmUiModelsRemoteDataSource) bind SwarmUiModelsDataSource.Remote::class + factoryOf(::SwarmUiLorasRemoteDataSource) bind LorasDataSource.Remote.SwarmUi::class factoryOf(::StableDiffusionGenerationRemoteDataSource) bind StableDiffusionGenerationDataSource.Remote::class factoryOf(::StableDiffusionSamplersRemoteDataSource) bind StableDiffusionSamplersDataSource.Remote::class factoryOf(::StableDiffusionModelsRemoteDataSource) bind StableDiffusionModelsDataSource.Remote::class - factoryOf(::StableDiffusionLorasRemoteDataSource) bind StableDiffusionLorasDataSource.Remote::class + factoryOf(::StableDiffusionLorasRemoteDataSource) bind LorasDataSource.Remote.Automatic1111::class factoryOf(::StableDiffusionHyperNetworksRemoteDataSource) bind StableDiffusionHyperNetworksDataSource.Remote::class factoryOf(::StableDiffusionEmbeddingsRemoteDataSource) bind StableDiffusionEmbeddingsDataSource.Remote::class factoryOf(::ServerConfigurationRemoteDataSource) bind ServerConfigurationDataSource.Remote::class diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/di/RepositoryModule.kt b/data/src/main/java/com/shifthackz/aisdv1/data/di/RepositoryModule.kt index 6e665fb3..3d48dd83 100755 --- a/data/src/main/java/com/shifthackz/aisdv1/data/di/RepositoryModule.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/di/RepositoryModule.kt @@ -8,6 +8,7 @@ import com.shifthackz.aisdv1.data.repository.HordeGenerationRepositoryImpl import com.shifthackz.aisdv1.data.repository.HuggingFaceGenerationRepositoryImpl import com.shifthackz.aisdv1.data.repository.HuggingFaceModelsRepositoryImpl import com.shifthackz.aisdv1.data.repository.LocalDiffusionGenerationRepositoryImpl +import com.shifthackz.aisdv1.data.repository.LorasRepositoryImpl import com.shifthackz.aisdv1.data.repository.OpenAiGenerationRepositoryImpl import com.shifthackz.aisdv1.data.repository.RandomImageRepositoryImpl import com.shifthackz.aisdv1.data.repository.ServerConfigurationRepositoryImpl @@ -17,7 +18,6 @@ import com.shifthackz.aisdv1.data.repository.StabilityAiGenerationRepositoryImpl import com.shifthackz.aisdv1.data.repository.StableDiffusionEmbeddingsRepositoryImpl import com.shifthackz.aisdv1.data.repository.StableDiffusionGenerationRepositoryImpl import com.shifthackz.aisdv1.data.repository.StableDiffusionHyperNetworksRepositoryImpl -import com.shifthackz.aisdv1.data.repository.StableDiffusionLorasRepositoryImpl import com.shifthackz.aisdv1.data.repository.StableDiffusionModelsRepositoryImpl import com.shifthackz.aisdv1.data.repository.StableDiffusionSamplersRepositoryImpl import com.shifthackz.aisdv1.data.repository.SwarmUiGenerationRepositoryImpl @@ -30,6 +30,7 @@ import com.shifthackz.aisdv1.domain.repository.HordeGenerationRepository import com.shifthackz.aisdv1.domain.repository.HuggingFaceGenerationRepository import com.shifthackz.aisdv1.domain.repository.HuggingFaceModelsRepository import com.shifthackz.aisdv1.domain.repository.LocalDiffusionGenerationRepository +import com.shifthackz.aisdv1.domain.repository.LorasRepository import com.shifthackz.aisdv1.domain.repository.OpenAiGenerationRepository import com.shifthackz.aisdv1.domain.repository.RandomImageRepository import com.shifthackz.aisdv1.domain.repository.ServerConfigurationRepository @@ -39,7 +40,6 @@ import com.shifthackz.aisdv1.domain.repository.StabilityAiGenerationRepository import com.shifthackz.aisdv1.domain.repository.StableDiffusionEmbeddingsRepository import com.shifthackz.aisdv1.domain.repository.StableDiffusionGenerationRepository import com.shifthackz.aisdv1.domain.repository.StableDiffusionHyperNetworksRepository -import com.shifthackz.aisdv1.domain.repository.StableDiffusionLorasRepository import com.shifthackz.aisdv1.domain.repository.StableDiffusionModelsRepository import com.shifthackz.aisdv1.domain.repository.StableDiffusionSamplersRepository import com.shifthackz.aisdv1.domain.repository.SwarmUiGenerationRepository @@ -72,7 +72,7 @@ val repositoryModule = module { factoryOf(::StableDiffusionGenerationRepositoryImpl) bind StableDiffusionGenerationRepository::class factoryOf(::StableDiffusionModelsRepositoryImpl) bind StableDiffusionModelsRepository::class factoryOf(::StableDiffusionSamplersRepositoryImpl) bind StableDiffusionSamplersRepository::class - factoryOf(::StableDiffusionLorasRepositoryImpl) bind StableDiffusionLorasRepository::class + factoryOf(::LorasRepositoryImpl) bind LorasRepository::class factoryOf(::StableDiffusionHyperNetworksRepositoryImpl) bind StableDiffusionHyperNetworksRepository::class factoryOf(::StableDiffusionEmbeddingsRepositoryImpl) bind StableDiffusionEmbeddingsRepository::class factoryOf(::ServerConfigurationRepositoryImpl) bind ServerConfigurationRepository::class diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/local/StableDiffusionLorasLocalDataSource.kt b/data/src/main/java/com/shifthackz/aisdv1/data/local/LorasLocalDataSource.kt similarity index 64% rename from data/src/main/java/com/shifthackz/aisdv1/data/local/StableDiffusionLorasLocalDataSource.kt rename to data/src/main/java/com/shifthackz/aisdv1/data/local/LorasLocalDataSource.kt index e130ade2..e38771e6 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/local/StableDiffusionLorasLocalDataSource.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/local/LorasLocalDataSource.kt @@ -2,20 +2,20 @@ package com.shifthackz.aisdv1.data.local import com.shifthackz.aisdv1.data.mappers.mapDomainToEntity import com.shifthackz.aisdv1.data.mappers.mapEntityToDomain -import com.shifthackz.aisdv1.domain.datasource.StableDiffusionLorasDataSource -import com.shifthackz.aisdv1.domain.entity.StableDiffusionLora +import com.shifthackz.aisdv1.domain.datasource.LorasDataSource +import com.shifthackz.aisdv1.domain.entity.LoRA import com.shifthackz.aisdv1.storage.db.cache.dao.StableDiffusionLoraDao import com.shifthackz.aisdv1.storage.db.cache.entity.StableDiffusionLoraEntity -internal class StableDiffusionLorasLocalDataSource( +internal class LorasLocalDataSource( private val dao: StableDiffusionLoraDao, -) : StableDiffusionLorasDataSource.Local { +) : LorasDataSource.Local { override fun getLoras() = dao .queryAll() .map(List::mapEntityToDomain) - override fun insertLoras(loras: List) = dao + override fun insertLoras(loras: List) = dao .deleteAll() .andThen(dao.insertList(loras.mapDomainToEntity())) } diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/HuggingFaceModelMappers.kt b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/HuggingFaceModelMappers.kt index 40c5d114..3ed4ad39 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/HuggingFaceModelMappers.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/HuggingFaceModelMappers.kt @@ -5,10 +5,10 @@ import com.shifthackz.aisdv1.network.model.HuggingFaceModelRaw import com.shifthackz.aisdv1.storage.db.persistent.entity.HuggingFaceModelEntity //region RAW --> DOMAIN -fun List.mapRawToDomain(): List = - map(HuggingFaceModelRaw::mapRawToDomain) +fun List.mapRawToCheckpointDomain(): List = + map(HuggingFaceModelRaw::mapRawToCheckpointDomain) -fun HuggingFaceModelRaw.mapRawToDomain(): HuggingFaceModel = with(this) { +fun HuggingFaceModelRaw.mapRawToCheckpointDomain(): HuggingFaceModel = with(this) { HuggingFaceModel( id = id ?: "", name = name ?: "", diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/ImageToImagePayloadMappers.kt b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/ImageToImagePayloadMappers.kt index 75f15ac7..8e93311c 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/ImageToImagePayloadMappers.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/ImageToImagePayloadMappers.kt @@ -1,5 +1,7 @@ package com.shifthackz.aisdv1.data.mappers +import com.shifthackz.aisdv1.core.common.math.roundTo +import com.shifthackz.aisdv1.core.imageprocessing.utils.base64DefaultToNoWrap import com.shifthackz.aisdv1.domain.entity.AiGenerationResult import com.shifthackz.aisdv1.domain.entity.ImageToImagePayload import com.shifthackz.aisdv1.domain.entity.StabilityAiClipGuidance @@ -7,9 +9,11 @@ import com.shifthackz.aisdv1.domain.entity.StabilityAiStylePreset import com.shifthackz.aisdv1.network.request.HordeGenerationAsyncRequest import com.shifthackz.aisdv1.network.request.HuggingFaceGenerationRequest import com.shifthackz.aisdv1.network.request.ImageToImageRequest +import com.shifthackz.aisdv1.network.request.SwarmUiGenerationRequest import com.shifthackz.aisdv1.network.response.SdGenerationResponse import java.util.Date +//region PAYLOAD --> REQUEST fun ImageToImagePayload.mapToRequest(): ImageToImageRequest = with(this) { ImageToImageRequest( initImages = listOf(base64Image), @@ -93,6 +97,30 @@ fun ImageToImagePayload.mapToStabilityAiRequest() = with(this) { } } +fun ImageToImagePayload.mapToSwarmUiRequest( + sessionId: String, + swarmUiModel: String, +): SwarmUiGenerationRequest = with(this) { + SwarmUiGenerationRequest( + sessionId = sessionId, + model = swarmUiModel, + initImage = "data:image/png;base64,${base64DefaultToNoWrap(base64Image)}", + initImageCreativity = denoisingStrength.roundTo(2).toString(), + images = 1, + prompt = prompt, + negativePrompt = negativePrompt, + width = width, + height = height, + seed = seed.trim().ifEmpty { null }, + variationSeed = subSeed.trim().ifEmpty { null }, + variationSeedStrength = subSeedStrength.takeIf { it >= 0.1 }?.toString(), + cfgScale = cfgScale, + steps = samplingSteps, + ) +} +//endregion + +//region RESPONSE --> RESULT fun Pair.mapToAiGenResult(): AiGenerationResult = let { (payload, response) -> AiGenerationResult( @@ -140,3 +168,4 @@ fun Pair.mapCloudToAiGenResult(): AiGenerationResul subSeedStrength = payload.subSeedStrength, ) } +//endregion diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/LocalAiModelMappers.kt b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/LocalAiModelMappers.kt index 7b8fbca6..d9261085 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/LocalAiModelMappers.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/LocalAiModelMappers.kt @@ -5,10 +5,10 @@ import com.shifthackz.aisdv1.network.response.DownloadableModelResponse import com.shifthackz.aisdv1.storage.db.persistent.entity.LocalModelEntity //region RAW --> DOMAIN -fun List.mapRawToDomain(): List = - map(DownloadableModelResponse::mapRawToDomain) +fun List.mapRawToCheckpointDomain(): List = + map(DownloadableModelResponse::mapRawToCheckpointDomain) -fun DownloadableModelResponse.mapRawToDomain(): LocalAiModel = with(this) { +fun DownloadableModelResponse.mapRawToCheckpointDomain(): LocalAiModel = with(this) { LocalAiModel( id = id ?: "", name = name ?: "", diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/StabilityAiEngineMappers.kt b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/StabilityAiEngineMappers.kt index d6ba31f3..44999ca4 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/StabilityAiEngineMappers.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/StabilityAiEngineMappers.kt @@ -4,10 +4,10 @@ import com.shifthackz.aisdv1.domain.entity.StabilityAiEngine import com.shifthackz.aisdv1.network.model.StabilityAiEngineRaw //region RAW --> DOMAIN -fun List.mapRawToDomain(): List = - map(StabilityAiEngineRaw::mapRawToDomain) +fun List.mapRawToCheckpointDomain(): List = + map(StabilityAiEngineRaw::mapRawToCheckpointDomain) -fun StabilityAiEngineRaw.mapRawToDomain(): StabilityAiEngine = with(this) { +fun StabilityAiEngineRaw.mapRawToCheckpointDomain(): StabilityAiEngine = with(this) { StabilityAiEngine(id ?: "", name ?: "") } //endregion diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/StableDiffusionEmbeddingsMappers.kt b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/StableDiffusionEmbeddingsMappers.kt index 554d1efb..216dc124 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/StableDiffusionEmbeddingsMappers.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/StableDiffusionEmbeddingsMappers.kt @@ -5,7 +5,7 @@ import com.shifthackz.aisdv1.network.response.SdEmbeddingsResponse import com.shifthackz.aisdv1.storage.db.cache.entity.StableDiffusionEmbeddingEntity //region RAW -> DOMAIN -fun SdEmbeddingsResponse.mapRawToDomain(): List = +fun SdEmbeddingsResponse.mapRawToCheckpointDomain(): List = loaded?.keys?.map(::StableDiffusionEmbedding) ?: emptyList() //endregion diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/StableDiffusionHyperNetworksMappers.kt b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/StableDiffusionHyperNetworksMappers.kt index 646d4eb7..0690d860 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/StableDiffusionHyperNetworksMappers.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/StableDiffusionHyperNetworksMappers.kt @@ -5,10 +5,10 @@ import com.shifthackz.aisdv1.network.model.StableDiffusionHyperNetworkRaw import com.shifthackz.aisdv1.storage.db.cache.entity.StableDiffusionHyperNetworkEntity //region RAW -> DOMAIN -fun List.mapRawToDomain(): List = - map(StableDiffusionHyperNetworkRaw::mapRawToDomain) +fun List.mapRawToCheckpointDomain(): List = + map(StableDiffusionHyperNetworkRaw::mapRawToCheckpointDomain) -fun StableDiffusionHyperNetworkRaw.mapRawToDomain(): StableDiffusionHyperNetwork = with(this) { +fun StableDiffusionHyperNetworkRaw.mapRawToCheckpointDomain(): StableDiffusionHyperNetwork = with(this) { StableDiffusionHyperNetwork( name = name ?: "", path = path ?: "", diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/StableDiffusionLorasMappers.kt b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/StableDiffusionLorasMappers.kt index f22c2018..ab382998 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/StableDiffusionLorasMappers.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/StableDiffusionLorasMappers.kt @@ -1,15 +1,15 @@ package com.shifthackz.aisdv1.data.mappers -import com.shifthackz.aisdv1.domain.entity.StableDiffusionLora +import com.shifthackz.aisdv1.domain.entity.LoRA import com.shifthackz.aisdv1.network.model.StableDiffusionLoraRaw import com.shifthackz.aisdv1.storage.db.cache.entity.StableDiffusionLoraEntity //region RAW --> DOMAIN -fun List.mapToDomain(): List = +fun List.mapToDomain(): List = map(StableDiffusionLoraRaw::mapToDomain) -fun StableDiffusionLoraRaw.mapToDomain(): StableDiffusionLora = with(this) { - StableDiffusionLora( +fun StableDiffusionLoraRaw.mapToDomain(): LoRA = with(this) { + LoRA( name = name ?: "", alias = alias ?: "", path = path ?: "", @@ -18,10 +18,10 @@ fun StableDiffusionLoraRaw.mapToDomain(): StableDiffusionLora = with(this) { //endregion //region DOMAIN -> ENTITY -fun List.mapDomainToEntity(): List = - map(StableDiffusionLora::mapDomainToEntity) +fun List.mapDomainToEntity(): List = + map(LoRA::mapDomainToEntity) -fun StableDiffusionLora.mapDomainToEntity(): StableDiffusionLoraEntity = with(this) { +fun LoRA.mapDomainToEntity(): StableDiffusionLoraEntity = with(this) { StableDiffusionLoraEntity( id = name, name = name, @@ -32,11 +32,11 @@ fun StableDiffusionLora.mapDomainToEntity(): StableDiffusionLoraEntity = with(th //endregion //region ENTITY -> DOMAIN -fun List.mapEntityToDomain(): List = +fun List.mapEntityToDomain(): List = map(StableDiffusionLoraEntity::mapEntityToDomain) -fun StableDiffusionLoraEntity.mapEntityToDomain(): StableDiffusionLora = with(this) { - StableDiffusionLora( +fun StableDiffusionLoraEntity.mapEntityToDomain(): LoRA = with(this) { + LoRA( name = name, alias = alias, path = path, diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/StableDiffusionModelsMappers.kt b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/StableDiffusionModelsMappers.kt index 8128cdf6..b9d03c2d 100755 --- a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/StableDiffusionModelsMappers.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/StableDiffusionModelsMappers.kt @@ -5,10 +5,10 @@ import com.shifthackz.aisdv1.network.model.StableDiffusionModelRaw import com.shifthackz.aisdv1.storage.db.cache.entity.StableDiffusionModelEntity //region RAW --> DOMAIN -fun List.mapRawToDomain(): List = - map(StableDiffusionModelRaw::mapRawToDomain) +fun List.mapRawToCheckpointDomain(): List = + map(StableDiffusionModelRaw::mapRawToCheckpointDomain) -fun StableDiffusionModelRaw.mapRawToDomain(): StableDiffusionModel = with(this) { +fun StableDiffusionModelRaw.mapRawToCheckpointDomain(): StableDiffusionModel = with(this) { StableDiffusionModel( title = title ?: "", modelName = modelName ?: "", diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/StableDiffusionSamplersMappers.kt b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/StableDiffusionSamplersMappers.kt index af87956c..dd932fa8 100755 --- a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/StableDiffusionSamplersMappers.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/StableDiffusionSamplersMappers.kt @@ -5,10 +5,10 @@ import com.shifthackz.aisdv1.network.model.StableDiffusionSamplerRaw import com.shifthackz.aisdv1.storage.db.cache.entity.StableDiffusionSamplerEntity //region RAW --> DOMAIN -fun List.mapRawToDomain(): List = - map(StableDiffusionSamplerRaw::mapRawToDomain) +fun List.mapRawToCheckpointDomain(): List = + map(StableDiffusionSamplerRaw::mapRawToCheckpointDomain) -fun StableDiffusionSamplerRaw.mapRawToDomain(): StableDiffusionSampler = with(this) { +fun StableDiffusionSamplerRaw.mapRawToCheckpointDomain(): StableDiffusionSampler = with(this) { StableDiffusionSampler( name = name ?: "", aliases = aliases ?: emptyList(), diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/SwarmUiImageToImagePayloadMappers.kt b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/SwarmUiImageToImagePayloadMappers.kt deleted file mode 100644 index 8514bef4..00000000 --- a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/SwarmUiImageToImagePayloadMappers.kt +++ /dev/null @@ -1,27 +0,0 @@ -package com.shifthackz.aisdv1.data.mappers - -import com.shifthackz.aisdv1.domain.entity.ImageToImagePayload -import com.shifthackz.aisdv1.network.request.SwarmUiGenerationRequest - -fun ImageToImagePayload.mapToSwarmUiRequest( - sessionId: String, - swarmUiModel: String, -): SwarmUiGenerationRequest = with(this) { - SwarmUiGenerationRequest( - sessionId = sessionId, - model = swarmUiModel, - initImage = "data:image/png;base64,${base64Image.trim('\n').trim('\u003d')}", -// initImage = base64Image, -// initImage = base64Image, - images = 1, - prompt = prompt, - negativePrompt = negativePrompt, - width = width, - height = height, - seed = seed.trim().ifEmpty { null }, - variationSeed = subSeed.trim().ifEmpty { null }, - variationSeedStrength = subSeedStrength.takeIf { it >= 0.1 }?.toString(), - cfgScale = cfgScale, - steps = samplingSteps, - ) -} diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/SwarmUiModelsMappers.kt b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/SwarmUiModelsMappers.kt index 1926fd3f..d3e5fd27 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/SwarmUiModelsMappers.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/SwarmUiModelsMappers.kt @@ -1,18 +1,19 @@ package com.shifthackz.aisdv1.data.mappers +import com.shifthackz.aisdv1.domain.entity.LoRA import com.shifthackz.aisdv1.domain.entity.SwarmUiModel import com.shifthackz.aisdv1.network.model.SwarmUiModelRaw import com.shifthackz.aisdv1.network.response.SwarmUiModelsResponse import com.shifthackz.aisdv1.storage.db.cache.entity.SwarmUiModelEntity -//region RAW --> DOMAIN -fun SwarmUiModelsResponse.mapRawToDomain(): List = with(this) { - this.files?.mapRawToDomain() ?: emptyList() +//region RAW --> CHECKPOINT DOMAIN +fun SwarmUiModelsResponse.mapRawToCheckpointDomain(): List = with(this) { + this.files?.mapRawToCheckpointDomain() ?: emptyList() } -fun List.mapRawToDomain(): List = map(SwarmUiModelRaw::mapRawToDomain) +fun List.mapRawToCheckpointDomain(): List = map(SwarmUiModelRaw::mapRawToCheckpointDomain) -fun SwarmUiModelRaw.mapRawToDomain(): SwarmUiModel = with(this) { +fun SwarmUiModelRaw.mapRawToCheckpointDomain(): SwarmUiModel = with(this) { SwarmUiModel( name = name ?: "", title = title ?: "", @@ -21,7 +22,23 @@ fun SwarmUiModelRaw.mapRawToDomain(): SwarmUiModel = with(this) { } //endregion -//region DOMAIN --> ENTITY +//region RAW --> LORA DOMAIN +fun SwarmUiModelsResponse.mapRawToLoraDomain(): List = with(this) { + this.files?.mapRawToLoraDomain() ?: emptyList() +} + +fun List.mapRawToLoraDomain(): List = map(SwarmUiModelRaw::mapRawToLoraDomain) + +fun SwarmUiModelRaw.mapRawToLoraDomain(): LoRA = with(this) { + LoRA( + name = name ?: "", + alias = title ?: "", + path = "", + ) +} +//endregion + +//region CHECKPOINT DOMAIN --> ENTITY fun List.mapDomainToEntity(): List = map(SwarmUiModel::mapDomainToEntity) fun SwarmUiModel.mapDomainToEntity(): SwarmUiModelEntity = with(this) { @@ -34,7 +51,7 @@ fun SwarmUiModel.mapDomainToEntity(): SwarmUiModelEntity = with(this) { } //endregion -//region ENTITY --> DOMAIN +//region ENTITY --> CHECKPOINT DOMAIN fun List.mapEntityToDomain(): List = map(SwarmUiModelEntity::mapEntityToDomain) fun SwarmUiModelEntity.mapEntityToDomain(): SwarmUiModel = with(this) { diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/SwarmUiTextToImagePayloadMappers.kt b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/SwarmUiTextToImagePayloadMappers.kt deleted file mode 100644 index f3ef2863..00000000 --- a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/SwarmUiTextToImagePayloadMappers.kt +++ /dev/null @@ -1,48 +0,0 @@ -package com.shifthackz.aisdv1.data.mappers - -import com.shifthackz.aisdv1.domain.entity.TextToImagePayload -import com.shifthackz.aisdv1.network.request.SwarmUiGenerationRequest - -fun TextToImagePayload.mapToSwarmUiRequest( - sessionId: String, - swarmUiModel: String, -): SwarmUiGenerationRequest = with(this) { - SwarmUiGenerationRequest( - sessionId = sessionId, - model = swarmUiModel, - initImage = null, - images = 1, - prompt = prompt, - negativePrompt = negativePrompt, - width = width, - height = height, - seed = seed.trim().ifEmpty { null }, - variationSeed = subSeed.trim().ifEmpty { null }, - variationSeedStrength = subSeedStrength.takeIf { it >= 0.1 }?.toString(), - cfgScale = cfgScale, - steps = samplingSteps, - ) -} -// -//fun Pair.mapToAiGenResult(): AiGenerationResult = -// let { (payload, response) -> -// AiGenerationResult( -// id = 0L, -// image = response.images?.firstOrNull() ?: "", -// inputImage = "", -// createdAt = Date(), -// type = AiGenerationResult.Type.TEXT_TO_IMAGE, -// denoisingStrength = 0f, -// prompt = payload.prompt, -// negativePrompt = payload.negativePrompt, -// width = payload.width, -// height = payload.height, -// samplingSteps = payload.samplingSteps, -// cfgScale = payload.cfgScale, -// restoreFaces = payload.restoreFaces, -// sampler = payload.sampler, -// seed = payload.seed, -// subSeed = payload.subSeed, -// subSeedStrength = payload.subSeedStrength, -// ) -// } diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/TextToImagePayloadMappers.kt b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/TextToImagePayloadMappers.kt index 37dba58a..36c76b21 100755 --- a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/TextToImagePayloadMappers.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/TextToImagePayloadMappers.kt @@ -10,10 +10,12 @@ import com.shifthackz.aisdv1.network.request.HordeGenerationAsyncRequest import com.shifthackz.aisdv1.network.request.HuggingFaceGenerationRequest import com.shifthackz.aisdv1.network.request.OpenAiRequest import com.shifthackz.aisdv1.network.request.StabilityTextToImageRequest +import com.shifthackz.aisdv1.network.request.SwarmUiGenerationRequest import com.shifthackz.aisdv1.network.request.TextToImageRequest import com.shifthackz.aisdv1.network.response.SdGenerationResponse import java.util.Date +//region PAYLOAD --> REQUEST fun TextToImagePayload.mapToRequest(): TextToImageRequest = with(this) { TextToImageRequest( prompt = prompt, @@ -95,6 +97,30 @@ fun TextToImagePayload.mapToStabilityAiRequest(): StabilityTextToImageRequest = ) } +fun TextToImagePayload.mapToSwarmUiRequest( + sessionId: String, + swarmUiModel: String, +): SwarmUiGenerationRequest = with(this) { + SwarmUiGenerationRequest( + sessionId = sessionId, + model = swarmUiModel, + initImage = null, + initImageCreativity = null, + images = 1, + prompt = prompt, + negativePrompt = negativePrompt, + width = width, + height = height, + seed = seed.trim().ifEmpty { null }, + variationSeed = subSeed.trim().ifEmpty { null }, + variationSeedStrength = subSeedStrength.takeIf { it >= 0.1 }?.toString(), + cfgScale = cfgScale, + steps = samplingSteps, + ) +} +//endregion + +//region RESPONSE --> RESULT fun Pair.mapToAiGenResult(): AiGenerationResult = let { (payload, response) -> AiGenerationResult( @@ -165,3 +191,4 @@ fun Pair.mapLocalDiffusionToAiGenResult(): AiGenerat subSeedStrength = payload.subSeedStrength, ) } +//endregion diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImpl.kt index c1651153..ad485239 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImpl.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImpl.kt @@ -20,21 +20,21 @@ class PreferenceManagerImpl( private val preferencesChangedSubject: BehaviorSubject = BehaviorSubject.createDefault(Unit) - override var serverUrl: String + override var automatic1111serverUrl: String get() = (preferences.getString(KEY_SERVER_URL, "") ?: "").fixUrlSlashes() set(value) = preferences.edit() .putString(KEY_SERVER_URL, value.fixUrlSlashes()) .apply() .also { onPreferencesChanged() } - override var swarmServerUrl: String + override var swarmUiServerUrl: String get() = (preferences.getString(KEY_SWARM_SERVER_URL, "") ?: "").fixUrlSlashes() set(value) = preferences.edit() .putString(KEY_SWARM_SERVER_URL, value.fixUrlSlashes()) .apply() .also { onPreferencesChanged() } - override var swarmModel: String + override var swarmUiModel: String get() = preferences.getString(KEY_SWARM_MODEL, "") ?: "" set(value) = preferences .edit() @@ -207,7 +207,7 @@ class PreferenceManagerImpl( .toFlowable(BackpressureStrategy.LATEST) .map { Settings( - serverUrl = serverUrl, + serverUrl = automatic1111serverUrl, sdModel = sdModel, demoMode = demoMode, monitorConnectivity = monitorConnectivity, diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/remote/DownloadableModelRemoteDataSource.kt b/data/src/main/java/com/shifthackz/aisdv1/data/remote/DownloadableModelRemoteDataSource.kt index 9241dcc6..a2e6b193 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/remote/DownloadableModelRemoteDataSource.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/remote/DownloadableModelRemoteDataSource.kt @@ -2,7 +2,7 @@ package com.shifthackz.aisdv1.data.remote import com.shifthackz.aisdv1.core.common.file.FileProviderDescriptor import com.shifthackz.aisdv1.core.common.file.unzip -import com.shifthackz.aisdv1.data.mappers.mapRawToDomain +import com.shifthackz.aisdv1.data.mappers.mapRawToCheckpointDomain import com.shifthackz.aisdv1.domain.datasource.DownloadableModelDataSource import com.shifthackz.aisdv1.domain.entity.DownloadState import com.shifthackz.aisdv1.network.api.sdai.DownloadableModelsApi @@ -18,7 +18,7 @@ internal class DownloadableModelRemoteDataSource( override fun fetch() = api .fetchDownloadableModels() - .map(List::mapRawToDomain) + .map(List::mapRawToCheckpointDomain) override fun download(id: String, url: String): Observable = Completable diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/remote/HuggingFaceModelsRemoteDataSource.kt b/data/src/main/java/com/shifthackz/aisdv1/data/remote/HuggingFaceModelsRemoteDataSource.kt index a336f341..73056260 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/remote/HuggingFaceModelsRemoteDataSource.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/remote/HuggingFaceModelsRemoteDataSource.kt @@ -1,6 +1,6 @@ package com.shifthackz.aisdv1.data.remote -import com.shifthackz.aisdv1.data.mappers.mapRawToDomain +import com.shifthackz.aisdv1.data.mappers.mapRawToCheckpointDomain import com.shifthackz.aisdv1.domain.datasource.HuggingFaceModelsDataSource import com.shifthackz.aisdv1.domain.entity.HuggingFaceModel import com.shifthackz.aisdv1.network.api.sdai.HuggingFaceModelsApi @@ -12,6 +12,6 @@ internal class HuggingFaceModelsRemoteDataSource( override fun fetchHuggingFaceModels() = api .fetchHuggingFaceModels() - .map(List::mapRawToDomain) + .map(List::mapRawToCheckpointDomain) .onErrorReturn { listOf(HuggingFaceModel.default) } } diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/remote/StabilityAiEnginesRemoteDataSource.kt b/data/src/main/java/com/shifthackz/aisdv1/data/remote/StabilityAiEnginesRemoteDataSource.kt index 7fdcd173..b6d68121 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/remote/StabilityAiEnginesRemoteDataSource.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/remote/StabilityAiEnginesRemoteDataSource.kt @@ -1,6 +1,6 @@ package com.shifthackz.aisdv1.data.remote -import com.shifthackz.aisdv1.data.mappers.mapRawToDomain +import com.shifthackz.aisdv1.data.mappers.mapRawToCheckpointDomain import com.shifthackz.aisdv1.domain.datasource.StabilityAiEnginesDataSource import com.shifthackz.aisdv1.domain.entity.StabilityAiEngine import com.shifthackz.aisdv1.network.api.stabilityai.StabilityAiApi @@ -13,5 +13,5 @@ internal class StabilityAiEnginesRemoteDataSource( override fun fetch(): Single> = api .fetchEngines() - .map(List::mapRawToDomain) + .map(List::mapRawToCheckpointDomain) } diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/remote/StableDiffusionEmbeddingsRemoteDataSource.kt b/data/src/main/java/com/shifthackz/aisdv1/data/remote/StableDiffusionEmbeddingsRemoteDataSource.kt index 0da9eb8b..4c211af3 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/remote/StableDiffusionEmbeddingsRemoteDataSource.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/remote/StableDiffusionEmbeddingsRemoteDataSource.kt @@ -1,6 +1,6 @@ package com.shifthackz.aisdv1.data.remote -import com.shifthackz.aisdv1.data.mappers.mapRawToDomain +import com.shifthackz.aisdv1.data.mappers.mapRawToCheckpointDomain import com.shifthackz.aisdv1.data.provider.ServerUrlProvider import com.shifthackz.aisdv1.domain.datasource.StableDiffusionEmbeddingsDataSource import com.shifthackz.aisdv1.network.api.automatic1111.Automatic1111RestApi @@ -14,5 +14,5 @@ internal class StableDiffusionEmbeddingsRemoteDataSource( override fun fetchEmbeddings() = serverUrlProvider(PATH_EMBEDDINGS) .flatMap(api::fetchEmbeddings) - .map(SdEmbeddingsResponse::mapRawToDomain) + .map(SdEmbeddingsResponse::mapRawToCheckpointDomain) } diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/remote/StableDiffusionHyperNetworksRemoteDataSource.kt b/data/src/main/java/com/shifthackz/aisdv1/data/remote/StableDiffusionHyperNetworksRemoteDataSource.kt index c72a3beb..059d8dd7 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/remote/StableDiffusionHyperNetworksRemoteDataSource.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/remote/StableDiffusionHyperNetworksRemoteDataSource.kt @@ -1,6 +1,6 @@ package com.shifthackz.aisdv1.data.remote -import com.shifthackz.aisdv1.data.mappers.mapRawToDomain +import com.shifthackz.aisdv1.data.mappers.mapRawToCheckpointDomain import com.shifthackz.aisdv1.data.provider.ServerUrlProvider import com.shifthackz.aisdv1.domain.datasource.StableDiffusionHyperNetworksDataSource import com.shifthackz.aisdv1.network.api.automatic1111.Automatic1111RestApi @@ -14,5 +14,5 @@ internal class StableDiffusionHyperNetworksRemoteDataSource( override fun fetchHyperNetworks() = serverUrlProvider(PATH_HYPER_NETWORKS) .flatMap(api::fetchHyperNetworks) - .map(List::mapRawToDomain) + .map(List::mapRawToCheckpointDomain) } diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/remote/StableDiffusionLorasRemoteDataSource.kt b/data/src/main/java/com/shifthackz/aisdv1/data/remote/StableDiffusionLorasRemoteDataSource.kt index 7019e6ed..dd0d546d 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/remote/StableDiffusionLorasRemoteDataSource.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/remote/StableDiffusionLorasRemoteDataSource.kt @@ -2,8 +2,8 @@ package com.shifthackz.aisdv1.data.remote import com.shifthackz.aisdv1.data.mappers.mapToDomain import com.shifthackz.aisdv1.data.provider.ServerUrlProvider -import com.shifthackz.aisdv1.domain.datasource.StableDiffusionLorasDataSource -import com.shifthackz.aisdv1.domain.entity.StableDiffusionLora +import com.shifthackz.aisdv1.domain.datasource.LorasDataSource +import com.shifthackz.aisdv1.domain.entity.LoRA import com.shifthackz.aisdv1.network.api.automatic1111.Automatic1111RestApi import com.shifthackz.aisdv1.network.api.automatic1111.Automatic1111RestApi.Companion.PATH_LORAS import com.shifthackz.aisdv1.network.model.StableDiffusionLoraRaw @@ -12,9 +12,9 @@ import io.reactivex.rxjava3.core.Single internal class StableDiffusionLorasRemoteDataSource( private val serverUrlProvider: ServerUrlProvider, private val api: Automatic1111RestApi, -) : StableDiffusionLorasDataSource.Remote { +) : LorasDataSource.Remote.Automatic1111 { - override fun fetchLoras(): Single> = serverUrlProvider(PATH_LORAS) + override fun fetchLoras(): Single> = serverUrlProvider(PATH_LORAS) .flatMap(api::fetchLoras) .map(List::mapToDomain) } diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/remote/StableDiffusionModelsRemoteDataSource.kt b/data/src/main/java/com/shifthackz/aisdv1/data/remote/StableDiffusionModelsRemoteDataSource.kt index e3a36fea..46485857 100755 --- a/data/src/main/java/com/shifthackz/aisdv1/data/remote/StableDiffusionModelsRemoteDataSource.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/remote/StableDiffusionModelsRemoteDataSource.kt @@ -1,6 +1,6 @@ package com.shifthackz.aisdv1.data.remote -import com.shifthackz.aisdv1.data.mappers.mapRawToDomain +import com.shifthackz.aisdv1.data.mappers.mapRawToCheckpointDomain import com.shifthackz.aisdv1.data.provider.ServerUrlProvider import com.shifthackz.aisdv1.domain.datasource.StableDiffusionModelsDataSource import com.shifthackz.aisdv1.network.api.automatic1111.Automatic1111RestApi @@ -14,5 +14,5 @@ internal class StableDiffusionModelsRemoteDataSource( override fun fetchSdModels() = serverUrlProvider(PATH_SD_MODELS) .flatMap(api::fetchSdModels) - .map(List::mapRawToDomain) + .map(List::mapRawToCheckpointDomain) } diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/remote/StableDiffusionSamplersRemoteDataSource.kt b/data/src/main/java/com/shifthackz/aisdv1/data/remote/StableDiffusionSamplersRemoteDataSource.kt index ad1200ec..44056fa7 100755 --- a/data/src/main/java/com/shifthackz/aisdv1/data/remote/StableDiffusionSamplersRemoteDataSource.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/remote/StableDiffusionSamplersRemoteDataSource.kt @@ -1,6 +1,6 @@ package com.shifthackz.aisdv1.data.remote -import com.shifthackz.aisdv1.data.mappers.mapRawToDomain +import com.shifthackz.aisdv1.data.mappers.mapRawToCheckpointDomain import com.shifthackz.aisdv1.data.provider.ServerUrlProvider import com.shifthackz.aisdv1.domain.datasource.StableDiffusionSamplersDataSource import com.shifthackz.aisdv1.network.api.automatic1111.Automatic1111RestApi @@ -14,5 +14,5 @@ internal class StableDiffusionSamplersRemoteDataSource( override fun fetchSamplers() = serverUrlProvider(PATH_SAMPLERS) .flatMap(api::fetchSamplers) - .map(List::mapRawToDomain) + .map(List::mapRawToCheckpointDomain) } \ No newline at end of file diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/remote/SwarmUiLorasRemoteDataSource.kt b/data/src/main/java/com/shifthackz/aisdv1/data/remote/SwarmUiLorasRemoteDataSource.kt new file mode 100644 index 00000000..0b6b14df --- /dev/null +++ b/data/src/main/java/com/shifthackz/aisdv1/data/remote/SwarmUiLorasRemoteDataSource.kt @@ -0,0 +1,29 @@ +package com.shifthackz.aisdv1.data.remote + +import com.shifthackz.aisdv1.data.mappers.mapRawToLoraDomain +import com.shifthackz.aisdv1.data.provider.ServerUrlProvider +import com.shifthackz.aisdv1.domain.datasource.LorasDataSource +import com.shifthackz.aisdv1.domain.entity.LoRA +import com.shifthackz.aisdv1.network.api.swarmui.SwarmUiApi +import com.shifthackz.aisdv1.network.api.swarmui.SwarmUiApi.Companion.PATH_MODELS +import com.shifthackz.aisdv1.network.request.SwarmUiModelsRequest +import com.shifthackz.aisdv1.network.response.SwarmUiModelsResponse +import io.reactivex.rxjava3.core.Single + +internal class SwarmUiLorasRemoteDataSource( + private val serverUrlProvider: ServerUrlProvider, + private val api: SwarmUiApi, +) : LorasDataSource.Remote.SwarmUi { + + override fun fetchLoras(sessionId: String): Single> = serverUrlProvider(PATH_MODELS) + .flatMap { url -> + val request = SwarmUiModelsRequest( + sessionId = sessionId, + subType = "LoRA", + path = "", + depth = 3, + ) + api.fetchModels(url, request) + } + .map(SwarmUiModelsResponse::mapRawToLoraDomain) +} diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/remote/SwarmUiModelsRemoteDataSource.kt b/data/src/main/java/com/shifthackz/aisdv1/data/remote/SwarmUiModelsRemoteDataSource.kt index 50e0605d..c2e57c04 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/remote/SwarmUiModelsRemoteDataSource.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/remote/SwarmUiModelsRemoteDataSource.kt @@ -1,6 +1,6 @@ package com.shifthackz.aisdv1.data.remote -import com.shifthackz.aisdv1.data.mappers.mapRawToDomain +import com.shifthackz.aisdv1.data.mappers.mapRawToCheckpointDomain import com.shifthackz.aisdv1.data.provider.ServerUrlProvider import com.shifthackz.aisdv1.domain.datasource.SwarmUiModelsDataSource import com.shifthackz.aisdv1.domain.entity.SwarmUiModel @@ -20,10 +20,11 @@ internal class SwarmUiModelsRemoteDataSource( .flatMap { url -> val request = SwarmUiModelsRequest( sessionId = sessionId, + subType = "Stable-Diffusion", path = "", depth = 3, ) api.fetchModels(url, request) } - .map(SwarmUiModelsResponse::mapRawToDomain) + .map(SwarmUiModelsResponse::mapRawToCheckpointDomain) } diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/repository/LorasRepositoryImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/repository/LorasRepositoryImpl.kt new file mode 100644 index 00000000..5e0fecaa --- /dev/null +++ b/data/src/main/java/com/shifthackz/aisdv1/data/repository/LorasRepositoryImpl.kt @@ -0,0 +1,38 @@ +package com.shifthackz.aisdv1.data.repository + +import com.shifthackz.aisdv1.domain.datasource.LorasDataSource +import com.shifthackz.aisdv1.domain.datasource.SwarmUiSessionDataSource +import com.shifthackz.aisdv1.domain.entity.LoRA +import com.shifthackz.aisdv1.domain.entity.ServerSource +import com.shifthackz.aisdv1.domain.preference.PreferenceManager +import com.shifthackz.aisdv1.domain.repository.LorasRepository +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single + +internal class LorasRepositoryImpl( + private val rdsA1111: LorasDataSource.Remote.Automatic1111, + private val rdsSwarm: LorasDataSource.Remote.SwarmUi, + private val swarmSession: SwarmUiSessionDataSource, + private val lds: LorasDataSource.Local, + private val preferenceManager: PreferenceManager, +) : LorasRepository { + + override fun fetchLoras(): Completable = when (preferenceManager.source) { + ServerSource.AUTOMATIC1111 -> rdsA1111 + .fetchLoras() + .flatMapCompletable(lds::insertLoras) + + ServerSource.SWARM_UI -> swarmSession + .getSessionId() + .flatMap(rdsSwarm::fetchLoras) + .flatMapCompletable(lds::insertLoras) + + else -> Completable.complete() + } + + override fun fetchAndGetLoras(): Single> = fetchLoras() + .onErrorComplete() + .andThen(getLoras()) + + override fun getLoras(): Single> = lds.getLoras() +} diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/repository/StableDiffusionLorasRepositoryImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/repository/StableDiffusionLorasRepositoryImpl.kt deleted file mode 100644 index 02e4d841..00000000 --- a/data/src/main/java/com/shifthackz/aisdv1/data/repository/StableDiffusionLorasRepositoryImpl.kt +++ /dev/null @@ -1,20 +0,0 @@ -package com.shifthackz.aisdv1.data.repository - -import com.shifthackz.aisdv1.domain.datasource.StableDiffusionLorasDataSource -import com.shifthackz.aisdv1.domain.repository.StableDiffusionLorasRepository - -internal class StableDiffusionLorasRepositoryImpl( - private val remoteDataSource: StableDiffusionLorasDataSource.Remote, - private val localDataSource: StableDiffusionLorasDataSource.Local, -) : StableDiffusionLorasRepository { - - override fun fetchLoras() = remoteDataSource - .fetchLoras() - .flatMapCompletable(localDataSource::insertLoras) - - override fun fetchAndGetLoras() = fetchLoras() - .onErrorComplete() - .andThen(getLoras()) - - override fun getLoras() = localDataSource.getLoras() -} diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/repository/SwarmUiGenerationRepositoryImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/repository/SwarmUiGenerationRepositoryImpl.kt index e4919e26..1ea98007 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/repository/SwarmUiGenerationRepositoryImpl.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/repository/SwarmUiGenerationRepositoryImpl.kt @@ -41,7 +41,7 @@ internal class SwarmUiGenerationRepositoryImpl( .flatMap { sessionId -> remoteDataSource.textToImage( sessionId = sessionId, - model = preferenceManager.swarmModel, + model = preferenceManager.swarmUiModel, payload = payload, ) } @@ -52,7 +52,7 @@ internal class SwarmUiGenerationRepositoryImpl( .flatMap { sessionId -> remoteDataSource.imageToImage( sessionId = sessionId, - model = preferenceManager.swarmModel, + model = preferenceManager.swarmUiModel, payload = payload, ) } diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/local/StableDiffusionLorasLocalDataSourceTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/local/StableDiffusionLorasLocalDataSourceTest.kt index 6a7baaa3..8ddec896 100644 --- a/data/src/test/java/com/shifthackz/aisdv1/data/local/StableDiffusionLorasLocalDataSourceTest.kt +++ b/data/src/test/java/com/shifthackz/aisdv1/data/local/StableDiffusionLorasLocalDataSourceTest.kt @@ -14,7 +14,7 @@ class StableDiffusionLorasLocalDataSourceTest { private val stubException = Throwable("Database error.") private val stubDao = mockk() - private val localDataSource = StableDiffusionLorasLocalDataSource(stubDao) + private val localDataSource = LorasLocalDataSource(stubDao) @Test fun `given attempt to get loras, dao returns list, expected valid domain model list value`() { diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StableDiffusionLoraMocks.kt b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StableDiffusionLoraMocks.kt index 07b40191..a6531aa2 100644 --- a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StableDiffusionLoraMocks.kt +++ b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StableDiffusionLoraMocks.kt @@ -1,14 +1,14 @@ package com.shifthackz.aisdv1.data.mocks -import com.shifthackz.aisdv1.domain.entity.StableDiffusionLora +import com.shifthackz.aisdv1.domain.entity.LoRA val mockStableDiffusionLoras = listOf( - StableDiffusionLora( + LoRA( name = "name_5598", alias = "alias_5598", path = "/unknown", ), - StableDiffusionLora( + LoRA( name = "name_151297", alias = "alias_151297", path = "/unknown", diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImplTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImplTest.kt index 0da16f96..6a0948a8 100644 --- a/data/src/test/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImplTest.kt +++ b/data/src/test/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImplTest.kt @@ -71,14 +71,14 @@ class PreferenceManagerImplTest { whenever(stubPreference.getString(eq(KEY_SERVER_URL), any())) .thenReturn("") - Assert.assertEquals("", preferenceManager.serverUrl) + Assert.assertEquals("", preferenceManager.automatic1111serverUrl) whenever(stubPreference.getString(eq(KEY_SERVER_URL), any())) .thenReturn("https://192.168.0.1:7860") - preferenceManager.serverUrl = "https://192.168.0.1:7860" + preferenceManager.automatic1111serverUrl = "https://192.168.0.1:7860" - Assert.assertEquals("https://192.168.0.1:7860", preferenceManager.serverUrl) + Assert.assertEquals("https://192.168.0.1:7860", preferenceManager.automatic1111serverUrl) preferenceManager .observe() diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/remote/DownloadableModelRemoteDataSourceTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/remote/DownloadableModelRemoteDataSourceTest.kt index c8e7339d..c97b0a9e 100644 --- a/data/src/test/java/com/shifthackz/aisdv1/data/remote/DownloadableModelRemoteDataSourceTest.kt +++ b/data/src/test/java/com/shifthackz/aisdv1/data/remote/DownloadableModelRemoteDataSourceTest.kt @@ -3,7 +3,7 @@ package com.shifthackz.aisdv1.data.remote import com.nhaarman.mockitokotlin2.mock import com.nhaarman.mockitokotlin2.whenever import com.shifthackz.aisdv1.core.common.file.FileProviderDescriptor -import com.shifthackz.aisdv1.data.mappers.mapRawToDomain +import com.shifthackz.aisdv1.data.mappers.mapRawToCheckpointDomain import com.shifthackz.aisdv1.data.mocks.mockDownloadableModelsResponse import com.shifthackz.aisdv1.network.api.sdai.DownloadableModelsApi import io.reactivex.rxjava3.core.Single @@ -25,7 +25,7 @@ class DownloadableModelRemoteDataSourceTest { whenever(stubApi.fetchDownloadableModels()) .thenReturn(Single.just(mockDownloadableModelsResponse)) - val expected = mockDownloadableModelsResponse.mapRawToDomain() + val expected = mockDownloadableModelsResponse.mapRawToCheckpointDomain() remoteDataSource .fetch() diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/remote/StableDiffusionLorasRemoteDataSourceTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/remote/StableDiffusionStableDiffusionLorasRemoteDataSourceTest.kt similarity index 92% rename from data/src/test/java/com/shifthackz/aisdv1/data/remote/StableDiffusionLorasRemoteDataSourceTest.kt rename to data/src/test/java/com/shifthackz/aisdv1/data/remote/StableDiffusionStableDiffusionLorasRemoteDataSourceTest.kt index 3c999085..7b0ae55c 100644 --- a/data/src/test/java/com/shifthackz/aisdv1/data/remote/StableDiffusionLorasRemoteDataSourceTest.kt +++ b/data/src/test/java/com/shifthackz/aisdv1/data/remote/StableDiffusionStableDiffusionLorasRemoteDataSourceTest.kt @@ -2,7 +2,7 @@ package com.shifthackz.aisdv1.data.remote import com.shifthackz.aisdv1.data.mocks.mockStableDiffusionLoraRaw import com.shifthackz.aisdv1.data.provider.ServerUrlProvider -import com.shifthackz.aisdv1.domain.entity.StableDiffusionLora +import com.shifthackz.aisdv1.domain.entity.LoRA import com.shifthackz.aisdv1.network.api.automatic1111.Automatic1111RestApi import io.mockk.every import io.mockk.mockk @@ -10,7 +10,7 @@ import io.reactivex.rxjava3.core.Single import org.junit.Before import org.junit.Test -class StableDiffusionLorasRemoteDataSourceTest { +class StableDiffusionStableDiffusionLorasRemoteDataSourceTest { private val stubException = Throwable("Internal server error.") private val stubUrlProvider = mockk() @@ -18,7 +18,7 @@ class StableDiffusionLorasRemoteDataSourceTest { private val remoteDataSource = StableDiffusionLorasRemoteDataSource( serverUrlProvider = stubUrlProvider, - api = stubApi, + automatic1111Api = stubApi, ) @Before @@ -39,7 +39,7 @@ class StableDiffusionLorasRemoteDataSourceTest { .test() .assertNoErrors() .assertValue { loras -> - loras is List + loras is List && loras.size == mockStableDiffusionLoraRaw.size } .await() diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/repository/StableDiffusionLorasRepositoryImplTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/repository/LorasRepositoryImplTest.kt similarity index 90% rename from data/src/test/java/com/shifthackz/aisdv1/data/repository/StableDiffusionLorasRepositoryImplTest.kt rename to data/src/test/java/com/shifthackz/aisdv1/data/repository/LorasRepositoryImplTest.kt index 5fbf1aaa..97513ffe 100644 --- a/data/src/test/java/com/shifthackz/aisdv1/data/repository/StableDiffusionLorasRepositoryImplTest.kt +++ b/data/src/test/java/com/shifthackz/aisdv1/data/repository/LorasRepositoryImplTest.kt @@ -1,23 +1,22 @@ package com.shifthackz.aisdv1.data.repository import com.shifthackz.aisdv1.data.mocks.mockStableDiffusionLoras -import com.shifthackz.aisdv1.data.mocks.mockStableDiffusionModels -import com.shifthackz.aisdv1.domain.datasource.StableDiffusionLorasDataSource +import com.shifthackz.aisdv1.domain.datasource.LorasDataSource import io.mockk.every import io.mockk.mockk import io.reactivex.rxjava3.core.Completable import io.reactivex.rxjava3.core.Single import org.junit.Test -class StableDiffusionLorasRepositoryImplTest { +class LorasRepositoryImplTest { private val stubException = Throwable("Something went wrong.") - private val stubRemoteDataSource = mockk() - private val stubLocalDataSource = mockk() + private val stubRemoteDataSource = mockk() + private val stubLocalDataSource = mockk() - private val repository = StableDiffusionLorasRepositoryImpl( - remoteDataSource = stubRemoteDataSource, - localDataSource = stubLocalDataSource, + private val repository = LorasRepositoryImpl( + rdsA1111 = stubRemoteDataSource, + lds = stubLocalDataSource, ) @Test diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/LorasDataSource.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/LorasDataSource.kt new file mode 100644 index 00000000..fd61ed0f --- /dev/null +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/LorasDataSource.kt @@ -0,0 +1,24 @@ +package com.shifthackz.aisdv1.domain.datasource + +import com.shifthackz.aisdv1.domain.entity.LoRA +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single + +sealed interface LorasDataSource { + + sealed interface Remote : LorasDataSource { + + interface Automatic1111 : Remote { + fun fetchLoras(): Single> + } + + interface SwarmUi : Remote { + fun fetchLoras(sessionId: String): Single> + } + } + + interface Local : LorasDataSource { + fun getLoras(): Single> + fun insertLoras(loras: List): Completable + } +} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/StableDiffusionLorasDataSource.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/StableDiffusionLorasDataSource.kt deleted file mode 100644 index d8db5e7b..00000000 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/StableDiffusionLorasDataSource.kt +++ /dev/null @@ -1,18 +0,0 @@ -package com.shifthackz.aisdv1.domain.datasource - -import com.shifthackz.aisdv1.domain.entity.StableDiffusionLora -import io.reactivex.rxjava3.core.Completable -import io.reactivex.rxjava3.core.Single - -sealed interface StableDiffusionLorasDataSource { - - interface Remote : StableDiffusionLorasDataSource { - fun fetchLoras(): Single> - } - - interface Local : StableDiffusionLorasDataSource { - fun getLoras(): Single> - - fun insertLoras(loras: List): Completable - } -} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/Configuration.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/Configuration.kt index ab8a4200..e644c6f8 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/Configuration.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/Configuration.kt @@ -5,6 +5,7 @@ import com.shifthackz.aisdv1.domain.feature.auth.AuthorizationCredentials data class Configuration( val serverUrl: String = "", val swarmUiUrl: String = "", + val swarmUiModel: String = "", val demoMode: Boolean = false, val source: ServerSource = ServerSource.AUTOMATIC1111, val hordeApiKey: String = "", diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/StableDiffusionLora.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/LoRA.kt similarity index 78% rename from domain/src/main/java/com/shifthackz/aisdv1/domain/entity/StableDiffusionLora.kt rename to domain/src/main/java/com/shifthackz/aisdv1/domain/entity/LoRA.kt index 25906524..54017459 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/StableDiffusionLora.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/LoRA.kt @@ -1,6 +1,6 @@ package com.shifthackz.aisdv1.domain.entity -data class StableDiffusionLora( +data class LoRA( val name: String, val alias: String, val path: String, diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/preference/PreferenceManager.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/preference/PreferenceManager.kt index 8f057184..c0f79422 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/preference/PreferenceManager.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/preference/PreferenceManager.kt @@ -5,9 +5,9 @@ import com.shifthackz.aisdv1.domain.entity.Settings import io.reactivex.rxjava3.core.Flowable interface PreferenceManager { - var serverUrl: String - var swarmServerUrl: String - var swarmModel: String + var automatic1111serverUrl: String + var swarmUiServerUrl: String + var swarmUiModel: String var demoMode: Boolean var monitorConnectivity: Boolean var autoSaveAiResults: Boolean diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/LorasRepository.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/LorasRepository.kt new file mode 100644 index 00000000..4beff109 --- /dev/null +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/LorasRepository.kt @@ -0,0 +1,11 @@ +package com.shifthackz.aisdv1.domain.repository + +import com.shifthackz.aisdv1.domain.entity.LoRA +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single + +interface LorasRepository { + fun fetchLoras(): Completable + fun fetchAndGetLoras(): Single> + fun getLoras(): Single> +} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/StableDiffusionLorasRepository.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/StableDiffusionLorasRepository.kt deleted file mode 100644 index 9a0d6f03..00000000 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/StableDiffusionLorasRepository.kt +++ /dev/null @@ -1,11 +0,0 @@ -package com.shifthackz.aisdv1.domain.repository - -import com.shifthackz.aisdv1.domain.entity.StableDiffusionLora -import io.reactivex.rxjava3.core.Completable -import io.reactivex.rxjava3.core.Single - -interface StableDiffusionLorasRepository { - fun fetchLoras(): Completable - fun fetchAndGetLoras(): Single> - fun getLoras(): Single> -} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/caching/DataPreLoaderUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/caching/DataPreLoaderUseCaseImpl.kt index c6eb7f77..a711718c 100755 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/caching/DataPreLoaderUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/caching/DataPreLoaderUseCaseImpl.kt @@ -1,18 +1,17 @@ package com.shifthackz.aisdv1.domain.usecase.caching +import com.shifthackz.aisdv1.domain.repository.LorasRepository import com.shifthackz.aisdv1.domain.repository.ServerConfigurationRepository import com.shifthackz.aisdv1.domain.repository.StableDiffusionEmbeddingsRepository import com.shifthackz.aisdv1.domain.repository.StableDiffusionHyperNetworksRepository -import com.shifthackz.aisdv1.domain.repository.StableDiffusionLorasRepository import com.shifthackz.aisdv1.domain.repository.StableDiffusionModelsRepository import com.shifthackz.aisdv1.domain.repository.StableDiffusionSamplersRepository -import io.reactivex.rxjava3.core.Completable internal class DataPreLoaderUseCaseImpl( private val serverConfigurationRepository: ServerConfigurationRepository, private val sdModelsRepository: StableDiffusionModelsRepository, private val sdSamplersRepository: StableDiffusionSamplersRepository, - private val sdLorasRepository: StableDiffusionLorasRepository, + private val sdLorasRepository: LorasRepository, private val sdHyperNetworksRepository: StableDiffusionHyperNetworksRepository, private val sdEmbeddingsRepository: StableDiffusionEmbeddingsRepository, ) : DataPreLoaderUseCase { diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/sdlora/FetchAndGetLorasUseCase.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/sdlora/FetchAndGetLorasUseCase.kt index 820df950..65936385 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/sdlora/FetchAndGetLorasUseCase.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/sdlora/FetchAndGetLorasUseCase.kt @@ -1,8 +1,8 @@ package com.shifthackz.aisdv1.domain.usecase.sdlora -import com.shifthackz.aisdv1.domain.entity.StableDiffusionLora +import com.shifthackz.aisdv1.domain.entity.LoRA import io.reactivex.rxjava3.core.Single interface FetchAndGetLorasUseCase { - operator fun invoke(): Single> + operator fun invoke(): Single> } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/sdlora/FetchAndGetLorasUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/sdlora/FetchAndGetLorasUseCaseImpl.kt index a6c6d3eb..fd9bc42f 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/sdlora/FetchAndGetLorasUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/sdlora/FetchAndGetLorasUseCaseImpl.kt @@ -1,10 +1,10 @@ package com.shifthackz.aisdv1.domain.usecase.sdlora -import com.shifthackz.aisdv1.domain.repository.StableDiffusionLorasRepository +import com.shifthackz.aisdv1.domain.repository.LorasRepository internal class FetchAndGetLorasUseCaseImpl( - private val stableDiffusionLorasRepository: StableDiffusionLorasRepository, + private val lorasRepository: LorasRepository, ) : FetchAndGetLorasUseCase { - override fun invoke() = stableDiffusionLorasRepository.getLoras() + override fun invoke() = lorasRepository.fetchAndGetLoras() } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImpl.kt index bcad1249..438ce9b4 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImpl.kt @@ -12,8 +12,9 @@ internal class GetConfigurationUseCaseImpl( override fun invoke(): Single = Single.just( Configuration( - serverUrl = preferenceManager.serverUrl, - swarmUiUrl = preferenceManager.swarmServerUrl, + serverUrl = preferenceManager.automatic1111serverUrl, + swarmUiUrl = preferenceManager.swarmUiServerUrl, + swarmUiModel = preferenceManager.swarmUiModel, demoMode = preferenceManager.demoMode, source = preferenceManager.source, hordeApiKey = preferenceManager.hordeApiKey, diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImpl.kt index 19c2d794..fab9973c 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImpl.kt @@ -14,8 +14,9 @@ internal class SetServerConfigurationUseCaseImpl( Completable.fromAction { authorizationStore.storeAuthorizationCredentials(configuration.authCredentials) preferenceManager.source = configuration.source - preferenceManager.serverUrl = configuration.serverUrl - preferenceManager.swarmServerUrl = configuration.swarmUiUrl + preferenceManager.automatic1111serverUrl = configuration.serverUrl + preferenceManager.swarmUiServerUrl = configuration.swarmUiUrl + preferenceManager.swarmUiModel = configuration.swarmUiModel preferenceManager.demoMode = configuration.demoMode preferenceManager.hordeApiKey = configuration.hordeApiKey preferenceManager.openAiApiKey = configuration.openAiApiKey @@ -24,6 +25,5 @@ internal class SetServerConfigurationUseCaseImpl( preferenceManager.stabilityAiApiKey = configuration.stabilityAiApiKey preferenceManager.stabilityAiEngineId = configuration.stabilityAiEngineId preferenceManager.localModelId = configuration.localModelId - println("SWARM - set ${preferenceManager.swarmServerUrl}") } } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/splash/SplashNavigationUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/splash/SplashNavigationUseCaseImpl.kt index f50bfc26..b550470d 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/splash/SplashNavigationUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/splash/SplashNavigationUseCaseImpl.kt @@ -15,7 +15,7 @@ internal class SplashNavigationUseCaseImpl( Action.LAUNCH_SERVER_SETUP } - preferenceManager.serverUrl.isEmpty() + preferenceManager.automatic1111serverUrl.isEmpty() && preferenceManager.source == ServerSource.AUTOMATIC1111 -> { Action.LAUNCH_SERVER_SETUP } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/swarmmodel/FetchAndGetSwarmUiModelsUseCase.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/swarmmodel/FetchAndGetSwarmUiModelsUseCase.kt index c672c93a..35b555fb 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/swarmmodel/FetchAndGetSwarmUiModelsUseCase.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/swarmmodel/FetchAndGetSwarmUiModelsUseCase.kt @@ -4,5 +4,5 @@ import com.shifthackz.aisdv1.domain.entity.SwarmUiModel import io.reactivex.rxjava3.core.Single interface FetchAndGetSwarmUiModelsUseCase { - operator fun invoke(): Single>> + operator fun invoke(): Single> } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/swarmmodel/FetchAndGetSwarmUiModelsUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/swarmmodel/FetchAndGetSwarmUiModelsUseCaseImpl.kt index d32191d6..556067ab 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/swarmmodel/FetchAndGetSwarmUiModelsUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/swarmmodel/FetchAndGetSwarmUiModelsUseCaseImpl.kt @@ -10,14 +10,12 @@ internal class FetchAndGetSwarmUiModelsUseCaseImpl( private val repository: SwarmUiModelsRepository, ) : FetchAndGetSwarmUiModelsUseCase { - override fun invoke(): Single>> = repository + override fun invoke(): Single> = repository .fetchAndGetModels() .map { models -> - if (!models.map(SwarmUiModel::name).contains(preferenceManager.swarmModel)) { - preferenceManager.swarmModel = models.firstOrNull()?.name ?: "" - } - models.map { model -> - model to (preferenceManager.swarmModel == model.name) + if (!models.map(SwarmUiModel::name).contains(preferenceManager.swarmUiModel)) { + preferenceManager.swarmUiModel = models.firstOrNull()?.name ?: "" } + models } } diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/caching/DataPreLoaderUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/caching/DataPreLoaderUseCaseImplTest.kt index 372373b5..7b422be1 100644 --- a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/caching/DataPreLoaderUseCaseImplTest.kt +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/caching/DataPreLoaderUseCaseImplTest.kt @@ -2,10 +2,10 @@ package com.shifthackz.aisdv1.domain.usecase.caching import com.nhaarman.mockitokotlin2.mock import com.nhaarman.mockitokotlin2.whenever +import com.shifthackz.aisdv1.domain.repository.LorasRepository import com.shifthackz.aisdv1.domain.repository.ServerConfigurationRepository import com.shifthackz.aisdv1.domain.repository.StableDiffusionEmbeddingsRepository import com.shifthackz.aisdv1.domain.repository.StableDiffusionHyperNetworksRepository -import com.shifthackz.aisdv1.domain.repository.StableDiffusionLorasRepository import com.shifthackz.aisdv1.domain.repository.StableDiffusionModelsRepository import com.shifthackz.aisdv1.domain.repository.StableDiffusionSamplersRepository import io.reactivex.rxjava3.core.Completable @@ -16,7 +16,7 @@ class DataPreLoaderUseCaseImplTest { private val stubServerConfigurationRepository = mock() private val stubStableDiffusionModelsRepository = mock() private val stubStableDiffusionSamplersRepository = mock() - private val stubStableDiffusionLorasRepository = mock() + private val stubLorasRepository = mock() private val stubStableDiffusionHyperNetworksRepository = mock() private val stubStableDiffusionEmbeddingsRepository = mock() @@ -24,7 +24,7 @@ class DataPreLoaderUseCaseImplTest { serverConfigurationRepository = stubServerConfigurationRepository, sdModelsRepository = stubStableDiffusionModelsRepository, sdSamplersRepository = stubStableDiffusionSamplersRepository, - sdLorasRepository = stubStableDiffusionLorasRepository, + sdLorasRepository = stubLorasRepository, sdHyperNetworksRepository = stubStableDiffusionHyperNetworksRepository, sdEmbeddingsRepository = stubStableDiffusionEmbeddingsRepository, ) @@ -40,7 +40,7 @@ class DataPreLoaderUseCaseImplTest { whenever(stubStableDiffusionSamplersRepository.fetchSamplers()) .thenReturn(Completable.complete()) - whenever(stubStableDiffusionLorasRepository.fetchLoras()) + whenever(stubLorasRepository.fetchLoras()) .thenReturn(Completable.complete()) whenever(stubStableDiffusionHyperNetworksRepository.fetchHyperNetworks()) @@ -69,7 +69,7 @@ class DataPreLoaderUseCaseImplTest { whenever(stubStableDiffusionSamplersRepository.fetchSamplers()) .thenReturn(Completable.complete()) - whenever(stubStableDiffusionLorasRepository.fetchLoras()) + whenever(stubLorasRepository.fetchLoras()) .thenReturn(Completable.complete()) whenever(stubStableDiffusionHyperNetworksRepository.fetchHyperNetworks()) @@ -98,7 +98,7 @@ class DataPreLoaderUseCaseImplTest { whenever(stubStableDiffusionSamplersRepository.fetchSamplers()) .thenReturn(Completable.complete()) - whenever(stubStableDiffusionLorasRepository.fetchLoras()) + whenever(stubLorasRepository.fetchLoras()) .thenReturn(Completable.complete()) whenever(stubStableDiffusionHyperNetworksRepository.fetchHyperNetworks()) @@ -127,7 +127,7 @@ class DataPreLoaderUseCaseImplTest { whenever(stubStableDiffusionSamplersRepository.fetchSamplers()) .thenReturn(Completable.error(stubException)) - whenever(stubStableDiffusionLorasRepository.fetchLoras()) + whenever(stubLorasRepository.fetchLoras()) .thenReturn(Completable.complete()) whenever(stubStableDiffusionHyperNetworksRepository.fetchHyperNetworks()) @@ -156,7 +156,7 @@ class DataPreLoaderUseCaseImplTest { whenever(stubStableDiffusionSamplersRepository.fetchSamplers()) .thenReturn(Completable.complete()) - whenever(stubStableDiffusionLorasRepository.fetchLoras()) + whenever(stubLorasRepository.fetchLoras()) .thenReturn(Completable.error(stubException)) whenever(stubStableDiffusionHyperNetworksRepository.fetchHyperNetworks()) @@ -185,7 +185,7 @@ class DataPreLoaderUseCaseImplTest { whenever(stubStableDiffusionSamplersRepository.fetchSamplers()) .thenReturn(Completable.complete()) - whenever(stubStableDiffusionLorasRepository.fetchLoras()) + whenever(stubLorasRepository.fetchLoras()) .thenReturn(Completable.complete()) whenever(stubStableDiffusionHyperNetworksRepository.fetchHyperNetworks()) @@ -214,7 +214,7 @@ class DataPreLoaderUseCaseImplTest { whenever(stubStableDiffusionSamplersRepository.fetchSamplers()) .thenReturn(Completable.complete()) - whenever(stubStableDiffusionLorasRepository.fetchLoras()) + whenever(stubLorasRepository.fetchLoras()) .thenReturn(Completable.complete()) whenever(stubStableDiffusionHyperNetworksRepository.fetchHyperNetworks()) diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImplTest.kt index c23282ba..b49d7b08 100644 --- a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImplTest.kt +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImplTest.kt @@ -25,7 +25,7 @@ class GetConfigurationUseCaseImplTest { } returns AuthorizationCredentials.None every { - stubPreferenceManager::serverUrl.get() + stubPreferenceManager::automatic1111serverUrl.get() } returns mockConfiguration.serverUrl every { diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImplTest.kt index fabc2fe4..0bf8b0cf 100644 --- a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImplTest.kt +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImplTest.kt @@ -28,7 +28,7 @@ class SetServerConfigurationUseCaseImplTest { } returns Unit every { - stubPreferenceManager::serverUrl.set(any()) + stubPreferenceManager::automatic1111serverUrl.set(any()) } returns Unit every { diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/splash/SplashNavigationUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/splash/SplashNavigationUseCaseImplTest.kt index c2787eeb..7419e773 100644 --- a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/splash/SplashNavigationUseCaseImplTest.kt +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/splash/SplashNavigationUseCaseImplTest.kt @@ -28,7 +28,7 @@ class SplashNavigationUseCaseImplTest { whenever(stubPreferenceManager.forceSetupAfterUpdate) .thenReturn(false) - whenever(stubPreferenceManager.serverUrl) + whenever(stubPreferenceManager.automatic1111serverUrl) .thenReturn("") whenever(stubPreferenceManager.source) @@ -45,7 +45,7 @@ class SplashNavigationUseCaseImplTest { whenever(stubPreferenceManager.forceSetupAfterUpdate) .thenReturn(false) - whenever(stubPreferenceManager.serverUrl) + whenever(stubPreferenceManager.automatic1111serverUrl) .thenReturn("http://192.168.0.1:7860") whenever(stubPreferenceManager.source) @@ -62,7 +62,7 @@ class SplashNavigationUseCaseImplTest { whenever(stubPreferenceManager.forceSetupAfterUpdate) .thenReturn(false) - whenever(stubPreferenceManager.serverUrl) + whenever(stubPreferenceManager.automatic1111serverUrl) .thenReturn("") whenever(stubPreferenceManager.source) @@ -79,7 +79,7 @@ class SplashNavigationUseCaseImplTest { whenever(stubPreferenceManager.forceSetupAfterUpdate) .thenReturn(false) - whenever(stubPreferenceManager.serverUrl) + whenever(stubPreferenceManager.automatic1111serverUrl) .thenReturn("http://192.168.0.1:7860") whenever(stubPreferenceManager.source) diff --git a/network/src/main/java/com/shifthackz/aisdv1/network/request/SwarmUiGenerationRequest.kt b/network/src/main/java/com/shifthackz/aisdv1/network/request/SwarmUiGenerationRequest.kt index 80a75bec..7bd6f84c 100644 --- a/network/src/main/java/com/shifthackz/aisdv1/network/request/SwarmUiGenerationRequest.kt +++ b/network/src/main/java/com/shifthackz/aisdv1/network/request/SwarmUiGenerationRequest.kt @@ -9,6 +9,8 @@ data class SwarmUiGenerationRequest( val model: String, @SerializedName("initimage") val initImage: String?, + @SerializedName("initimagecreativity") + val initImageCreativity: String?, @SerializedName("images") val images: Int, @SerializedName("prompt") @@ -29,11 +31,9 @@ data class SwarmUiGenerationRequest( val cfgScale: Float?, @SerializedName("steps") val steps: Int, - @SerializedName("initimagecreativity") - val initimagecreativity: String = "0.6", - @SerializedName("initimageresettonorm") - val initimageresettonorm: String = "0", - @SerializedName("initimagerecompositemask") - val initimagerecompositemask: String = "0", +// @SerializedName("initimageresettonorm") +// val initimageresettonorm: String = "0", +// @SerializedName("initimagerecompositemask") +// val initimagerecompositemask: String = "0", ) diff --git a/network/src/main/java/com/shifthackz/aisdv1/network/request/SwarmUiModelsRequest.kt b/network/src/main/java/com/shifthackz/aisdv1/network/request/SwarmUiModelsRequest.kt index 83a98aa3..35461f67 100644 --- a/network/src/main/java/com/shifthackz/aisdv1/network/request/SwarmUiModelsRequest.kt +++ b/network/src/main/java/com/shifthackz/aisdv1/network/request/SwarmUiModelsRequest.kt @@ -5,6 +5,8 @@ import com.google.gson.annotations.SerializedName data class SwarmUiModelsRequest( @SerializedName("session_id") val sessionId: String, + @SerializedName("subtype") + val subType: String, @SerializedName("path") val path: String, @SerializedName("depth") diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/extras/ExtrasScreen.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/extras/ExtrasScreen.kt index 5ae52cee..2d669744 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/extras/ExtrasScreen.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/extras/ExtrasScreen.kt @@ -204,6 +204,7 @@ private fun ExtrasEmptyState(type: ExtraType) { .align(Alignment.CenterHorizontally), text = stringResource( id = when (type) { + //ToDo change empty state path depending on A1111/SWARM provider ExtraType.Lora -> R.string.extras_empty_sub_title_lora ExtraType.HyperNet -> R.string.extras_empty_sub_title_hypernet } diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/extras/ExtrasViewModel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/extras/ExtrasViewModel.kt index d5bcd1d8..53e53670 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/extras/ExtrasViewModel.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/extras/ExtrasViewModel.kt @@ -5,8 +5,8 @@ import com.shifthackz.aisdv1.core.common.schedulers.SchedulersProvider import com.shifthackz.aisdv1.core.common.schedulers.subscribeOnMainThread import com.shifthackz.aisdv1.core.common.time.TimeProvider import com.shifthackz.aisdv1.core.viewmodel.MviRxViewModel +import com.shifthackz.aisdv1.domain.entity.LoRA import com.shifthackz.aisdv1.domain.entity.StableDiffusionHyperNetwork -import com.shifthackz.aisdv1.domain.entity.StableDiffusionLora import com.shifthackz.aisdv1.domain.usecase.sdhypernet.FetchAndGetHyperNetworksUseCase import com.shifthackz.aisdv1.domain.usecase.sdlora.FetchAndGetLorasUseCase import com.shifthackz.aisdv1.presentation.model.ErrorState @@ -72,14 +72,14 @@ class ExtrasViewModel( val (isApplied, value) = ExtrasFormatter.isExtraWithValuePresentInPrompt( prompt = prompt, loraAlias = when (it) { - is StableDiffusionLora -> it.alias + is LoRA -> it.alias is StableDiffusionHyperNetwork -> it.name else -> "" }, type = type, ) when (it) { - is StableDiffusionLora -> ExtraItemUi( + is LoRA -> ExtraItemUi( type = type, key = "${it.name}_${type}_${timeProvider.nanoTime()}", name = it.name, diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModel.kt index a541754f..4b097c13 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModel.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModel.kt @@ -27,7 +27,7 @@ class EngineSelectionViewModel( private val getConfigurationUseCase: GetConfigurationUseCase, private val selectStableDiffusionModelUseCase: SelectStableDiffusionModelUseCase, private val getStableDiffusionModelsUseCase: GetStableDiffusionModelsUseCase, - private val fetchAndGetSwarmUiModelsUseCase: FetchAndGetSwarmUiModelsUseCase, + fetchAndGetSwarmUiModelsUseCase: FetchAndGetSwarmUiModelsUseCase, observeLocalAiModelsUseCase: ObserveLocalAiModelsUseCase, fetchAndGetStabilityAiEnginesUseCase: FetchAndGetStabilityAiEnginesUseCase, getHuggingFaceModelsUseCase: FetchAndGetHuggingFaceModelsUseCase, @@ -82,8 +82,8 @@ class EngineSelectionViewModel( sdModels = sdModels.map { it.first.title }, selectedSdModel = sdModels.firstOrNull { it.second }?.first?.title ?: state.selectedSdModel, - swarmModels = swarmModels.map { it.first.name }, - selectedSwarmModel = swarmModels.firstOrNull { it.second }?.first?.name + swarmModels = swarmModels.map { it.name }, + selectedSwarmModel = swarmModels.firstOrNull { it.name == config.swarmUiModel }?.name ?: state.selectedSwarmModel, hfModels = hfModels.map { it.alias }, selectedHfModel = config.huggingFaceModel, @@ -121,7 +121,7 @@ class EngineSelectionViewModel( } } - ServerSource.SWARM_UI -> preferenceManager.swarmModel = intent.value + ServerSource.SWARM_UI -> preferenceManager.swarmUiModel = intent.value ServerSource.HUGGING_FACE -> preferenceManager.huggingFaceModel = intent.value diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputForm.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputForm.kt index f82e6ce3..59bb7600 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputForm.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputForm.kt @@ -529,6 +529,7 @@ fun GenerationInputForm( when (state.mode) { ServerSource.AUTOMATIC1111, + ServerSource.SWARM_UI, ServerSource.STABILITY_AI, ServerSource.HORDE -> afterSlidersSection() @@ -564,7 +565,7 @@ fun GenerationInputForm( } } -fun processTaggedPrompt(keywords: List, event: ChipTextFieldEvent): String { +private fun processTaggedPrompt(keywords: List, event: ChipTextFieldEvent): String { val newKeywords = when (event) { is ChipTextFieldEvent.Add -> buildList { addAll(keywords) diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/toolbar/GenearionBottomToolbar.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/toolbar/GenearionBottomToolbar.kt index 83d3fbe4..9ee88d63 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/toolbar/GenearionBottomToolbar.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/toolbar/GenearionBottomToolbar.kt @@ -49,26 +49,30 @@ fun GenerationBottomToolbar( .padding(top = 8.dp), contentAlignment = Alignment.BottomCenter, ) { - if (state.mode == ServerSource.AUTOMATIC1111) { - GenerationBottomToolbarBottomLayer( - modifier = Modifier.padding(bottom = 36.dp), - strokeAccentState = strokeAccentState, - state = state, - processIntent = processIntent, - ) - Box( - modifier = Modifier - .fillMaxWidth() - .height(60.dp) - .padding(horizontal = 16.dp) - .clip( - RoundedCornerShape( - topStart = 22.dp, - topEnd = 22.dp, + when (state.mode) { + ServerSource.AUTOMATIC1111, + ServerSource.SWARM_UI -> { + GenerationBottomToolbarBottomLayer( + modifier = Modifier.padding(bottom = 36.dp), + strokeAccentState = strokeAccentState, + state = state, + processIntent = processIntent, + ) + Box( + modifier = Modifier + .fillMaxWidth() + .height(60.dp) + .padding(horizontal = 16.dp) + .clip( + RoundedCornerShape( + topStart = 22.dp, + topEnd = 22.dp, + ) ) - ) - .background(color = MaterialTheme.colorScheme.surface), - ) + .background(color = MaterialTheme.colorScheme.surface), + ) + } + else -> Unit } content() } diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/StableDiffusionLoraMocks.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/StableDiffusionLoraMocks.kt index dd20509a..8cadd972 100644 --- a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/StableDiffusionLoraMocks.kt +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/StableDiffusionLoraMocks.kt @@ -1,14 +1,14 @@ package com.shifthackz.aisdv1.presentation.mocks -import com.shifthackz.aisdv1.domain.entity.StableDiffusionLora +import com.shifthackz.aisdv1.domain.entity.LoRA val mockStableDiffusionLoras = listOf( - StableDiffusionLora( + LoRA( name = "name_5598", alias = "alias_5598", path = "/unknown", ), - StableDiffusionLora( + LoRA( name = "name_151297", alias = "alias_151297", path = "/unknown", From 01c7d2ba873c1e5a38e1a91366e0c2212ce13ec8 Mon Sep 17 00:00:00 2001 From: ShiftHackZ Date: Sun, 4 Aug 2024 23:01:43 +0300 Subject: [PATCH 04/11] Swarm UI Embeddings implementation --- .../aisdv1/data/di/LocalDataSourceModule.kt | 6 +-- .../aisdv1/data/di/RemoteDataSourceModule.kt | 6 ++- .../aisdv1/data/di/RepositoryModule.kt | 6 +-- ...Source.kt => EmbeddingsLocalDataSource.kt} | 10 ++--- .../StableDiffusionEmbeddingsMappers.kt | 18 ++++----- .../data/mappers/SwarmUiModelsMappers.kt | 13 +++++++ ...ableDiffusionEmbeddingsRemoteDataSource.kt | 4 +- .../SwarmUiEmbeddingsRemoteDataSource.kt | 29 ++++++++++++++ .../repository/EmbeddingsRepositoryImpl.kt | 38 +++++++++++++++++++ ...StableDiffusionEmbeddingsRepositoryImpl.kt | 20 ---------- ...st.kt => EmbeddingsLocalDataSourceTest.kt} | 14 +++---- .../mocks/StableDiffusionEmbeddingMocks.kt | 8 ++-- ...DiffusionEmbeddingsRemoteDataSourceTest.kt | 4 +- ...est.kt => EmbeddingsRepositoryImplTest.kt} | 34 ++++++++--------- .../domain/datasource/EmbeddingsDataSource.kt | 24 ++++++++++++ .../StableDiffusionEmbeddingsDataSource.kt | 17 --------- ...ableDiffusionEmbedding.kt => Embedding.kt} | 2 +- .../domain/repository/EmbeddingsRepository.kt | 11 ++++++ .../StableDiffusionEmbeddingsRepository.kt | 11 ------ .../caching/DataPreLoaderUseCaseImpl.kt | 4 +- .../FetchAndGetEmbeddingsUseCase.kt | 4 +- .../FetchAndGetEmbeddingsUseCaseImpl.kt | 8 ++-- .../mocks/StableDiffusionEmbeddingMocks.kt | 10 ++--- .../caching/DataPreLoaderUseCaseImplTest.kt | 20 +++++----- .../FetchAndGetEmbeddingsUseCaseImplTest.kt | 10 ++--- .../mocks/StableDiffusionEmbeddingMocks.kt | 8 ++-- .../modal/embedding/EmbeddingViewModelTest.kt | 6 +-- 27 files changed, 208 insertions(+), 137 deletions(-) rename data/src/main/java/com/shifthackz/aisdv1/data/local/{StableDiffusionEmbeddingsLocalDataSource.kt => EmbeddingsLocalDataSource.kt} (63%) create mode 100644 data/src/main/java/com/shifthackz/aisdv1/data/remote/SwarmUiEmbeddingsRemoteDataSource.kt create mode 100644 data/src/main/java/com/shifthackz/aisdv1/data/repository/EmbeddingsRepositoryImpl.kt delete mode 100644 data/src/main/java/com/shifthackz/aisdv1/data/repository/StableDiffusionEmbeddingsRepositoryImpl.kt rename data/src/test/java/com/shifthackz/aisdv1/data/local/{StableDiffusionEmbeddingsLocalDataSourceTest.kt => EmbeddingsLocalDataSourceTest.kt} (87%) rename data/src/test/java/com/shifthackz/aisdv1/data/repository/{StableDiffusionEmbeddingsRepositoryImplTest.kt => EmbeddingsRepositoryImplTest.kt} (81%) create mode 100644 domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/EmbeddingsDataSource.kt delete mode 100644 domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/StableDiffusionEmbeddingsDataSource.kt rename domain/src/main/java/com/shifthackz/aisdv1/domain/entity/{StableDiffusionEmbedding.kt => Embedding.kt} (66%) create mode 100644 domain/src/main/java/com/shifthackz/aisdv1/domain/repository/EmbeddingsRepository.kt delete mode 100644 domain/src/main/java/com/shifthackz/aisdv1/domain/repository/StableDiffusionEmbeddingsRepository.kt diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/di/LocalDataSourceModule.kt b/data/src/main/java/com/shifthackz/aisdv1/data/di/LocalDataSourceModule.kt index 4373902e..460c294c 100755 --- a/data/src/main/java/com/shifthackz/aisdv1/data/di/LocalDataSourceModule.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/di/LocalDataSourceModule.kt @@ -3,23 +3,23 @@ package com.shifthackz.aisdv1.data.di import com.shifthackz.aisdv1.data.gateway.DatabaseClearGatewayImpl import com.shifthackz.aisdv1.data.gateway.mediastore.MediaStoreGatewayFactory import com.shifthackz.aisdv1.data.local.DownloadableModelLocalDataSource +import com.shifthackz.aisdv1.data.local.EmbeddingsLocalDataSource import com.shifthackz.aisdv1.data.local.GenerationResultLocalDataSource import com.shifthackz.aisdv1.data.local.HuggingFaceModelsLocalDataSource import com.shifthackz.aisdv1.data.local.LorasLocalDataSource import com.shifthackz.aisdv1.data.local.ServerConfigurationLocalDataSource import com.shifthackz.aisdv1.data.local.StabilityAiCreditsLocalDataSource -import com.shifthackz.aisdv1.data.local.StableDiffusionEmbeddingsLocalDataSource import com.shifthackz.aisdv1.data.local.StableDiffusionHyperNetworksLocalDataSource import com.shifthackz.aisdv1.data.local.StableDiffusionModelsLocalDataSource import com.shifthackz.aisdv1.data.local.StableDiffusionSamplersLocalDataSource import com.shifthackz.aisdv1.data.local.SwarmUiModelsLocalDataSource import com.shifthackz.aisdv1.domain.datasource.DownloadableModelDataSource +import com.shifthackz.aisdv1.domain.datasource.EmbeddingsDataSource import com.shifthackz.aisdv1.domain.datasource.GenerationResultDataSource import com.shifthackz.aisdv1.domain.datasource.HuggingFaceModelsDataSource import com.shifthackz.aisdv1.domain.datasource.LorasDataSource import com.shifthackz.aisdv1.domain.datasource.ServerConfigurationDataSource import com.shifthackz.aisdv1.domain.datasource.StabilityAiCreditsDataSource -import com.shifthackz.aisdv1.domain.datasource.StableDiffusionEmbeddingsDataSource import com.shifthackz.aisdv1.domain.datasource.StableDiffusionHyperNetworksDataSource import com.shifthackz.aisdv1.domain.datasource.StableDiffusionModelsDataSource import com.shifthackz.aisdv1.domain.datasource.StableDiffusionSamplersDataSource @@ -39,7 +39,7 @@ val localDataSourceModule = module { factoryOf(::StableDiffusionSamplersLocalDataSource) bind StableDiffusionSamplersDataSource.Local::class factoryOf(::LorasLocalDataSource) bind LorasDataSource.Local::class factoryOf(::StableDiffusionHyperNetworksLocalDataSource) bind StableDiffusionHyperNetworksDataSource.Local::class - factoryOf(::StableDiffusionEmbeddingsLocalDataSource) bind StableDiffusionEmbeddingsDataSource.Local::class + factoryOf(::EmbeddingsLocalDataSource) bind EmbeddingsDataSource.Local::class factoryOf(::SwarmUiModelsLocalDataSource) bind SwarmUiModelsDataSource.Local::class factoryOf(::ServerConfigurationLocalDataSource) bind ServerConfigurationDataSource.Local::class factoryOf(::GenerationResultLocalDataSource) bind GenerationResultDataSource.Local::class diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/di/RemoteDataSourceModule.kt b/data/src/main/java/com/shifthackz/aisdv1/data/di/RemoteDataSourceModule.kt index 25cdba92..40786808 100755 --- a/data/src/main/java/com/shifthackz/aisdv1/data/di/RemoteDataSourceModule.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/di/RemoteDataSourceModule.kt @@ -20,11 +20,13 @@ import com.shifthackz.aisdv1.data.remote.StableDiffusionHyperNetworksRemoteDataS import com.shifthackz.aisdv1.data.remote.StableDiffusionLorasRemoteDataSource import com.shifthackz.aisdv1.data.remote.StableDiffusionModelsRemoteDataSource import com.shifthackz.aisdv1.data.remote.StableDiffusionSamplersRemoteDataSource +import com.shifthackz.aisdv1.data.remote.SwarmUiEmbeddingsRemoteDataSource import com.shifthackz.aisdv1.data.remote.SwarmUiGenerationRemoteDataSource import com.shifthackz.aisdv1.data.remote.SwarmUiLorasRemoteDataSource import com.shifthackz.aisdv1.data.remote.SwarmUiModelsRemoteDataSource import com.shifthackz.aisdv1.data.remote.SwarmUiSessionDataSourceImpl import com.shifthackz.aisdv1.domain.datasource.DownloadableModelDataSource +import com.shifthackz.aisdv1.domain.datasource.EmbeddingsDataSource import com.shifthackz.aisdv1.domain.datasource.HordeGenerationDataSource import com.shifthackz.aisdv1.domain.datasource.HuggingFaceGenerationDataSource import com.shifthackz.aisdv1.domain.datasource.HuggingFaceModelsDataSource @@ -35,7 +37,6 @@ import com.shifthackz.aisdv1.domain.datasource.ServerConfigurationDataSource import com.shifthackz.aisdv1.domain.datasource.StabilityAiCreditsDataSource import com.shifthackz.aisdv1.domain.datasource.StabilityAiEnginesDataSource import com.shifthackz.aisdv1.domain.datasource.StabilityAiGenerationDataSource -import com.shifthackz.aisdv1.domain.datasource.StableDiffusionEmbeddingsDataSource import com.shifthackz.aisdv1.domain.datasource.StableDiffusionGenerationDataSource import com.shifthackz.aisdv1.domain.datasource.StableDiffusionHyperNetworksDataSource import com.shifthackz.aisdv1.domain.datasource.StableDiffusionModelsDataSource @@ -76,12 +77,13 @@ val remoteDataSourceModule = module { factoryOf(::SwarmUiGenerationRemoteDataSource) bind SwarmUiGenerationDataSource.Remote::class factoryOf(::SwarmUiModelsRemoteDataSource) bind SwarmUiModelsDataSource.Remote::class factoryOf(::SwarmUiLorasRemoteDataSource) bind LorasDataSource.Remote.SwarmUi::class + factoryOf(::SwarmUiEmbeddingsRemoteDataSource) bind EmbeddingsDataSource.Remote.SwarmUi::class factoryOf(::StableDiffusionGenerationRemoteDataSource) bind StableDiffusionGenerationDataSource.Remote::class factoryOf(::StableDiffusionSamplersRemoteDataSource) bind StableDiffusionSamplersDataSource.Remote::class factoryOf(::StableDiffusionModelsRemoteDataSource) bind StableDiffusionModelsDataSource.Remote::class factoryOf(::StableDiffusionLorasRemoteDataSource) bind LorasDataSource.Remote.Automatic1111::class factoryOf(::StableDiffusionHyperNetworksRemoteDataSource) bind StableDiffusionHyperNetworksDataSource.Remote::class - factoryOf(::StableDiffusionEmbeddingsRemoteDataSource) bind StableDiffusionEmbeddingsDataSource.Remote::class + factoryOf(::StableDiffusionEmbeddingsRemoteDataSource) bind EmbeddingsDataSource.Remote.Automatic1111::class factoryOf(::ServerConfigurationRemoteDataSource) bind ServerConfigurationDataSource.Remote::class factoryOf(::RandomImageRemoteDataSource) bind RandomImageDataSource.Remote::class factoryOf(::DownloadableModelRemoteDataSource) bind DownloadableModelDataSource.Remote::class diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/di/RepositoryModule.kt b/data/src/main/java/com/shifthackz/aisdv1/data/di/RepositoryModule.kt index 3d48dd83..d8e40880 100755 --- a/data/src/main/java/com/shifthackz/aisdv1/data/di/RepositoryModule.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/di/RepositoryModule.kt @@ -3,6 +3,7 @@ package com.shifthackz.aisdv1.data.di import android.content.Context import android.os.PowerManager import com.shifthackz.aisdv1.data.repository.DownloadableModelRepositoryImpl +import com.shifthackz.aisdv1.data.repository.EmbeddingsRepositoryImpl import com.shifthackz.aisdv1.data.repository.GenerationResultRepositoryImpl import com.shifthackz.aisdv1.data.repository.HordeGenerationRepositoryImpl import com.shifthackz.aisdv1.data.repository.HuggingFaceGenerationRepositoryImpl @@ -15,7 +16,6 @@ import com.shifthackz.aisdv1.data.repository.ServerConfigurationRepositoryImpl import com.shifthackz.aisdv1.data.repository.StabilityAiCreditsRepositoryImpl import com.shifthackz.aisdv1.data.repository.StabilityAiEnginesRepositoryImpl import com.shifthackz.aisdv1.data.repository.StabilityAiGenerationRepositoryImpl -import com.shifthackz.aisdv1.data.repository.StableDiffusionEmbeddingsRepositoryImpl import com.shifthackz.aisdv1.data.repository.StableDiffusionGenerationRepositoryImpl import com.shifthackz.aisdv1.data.repository.StableDiffusionHyperNetworksRepositoryImpl import com.shifthackz.aisdv1.data.repository.StableDiffusionModelsRepositoryImpl @@ -25,6 +25,7 @@ import com.shifthackz.aisdv1.data.repository.SwarmUiModelsRepositoryImpl import com.shifthackz.aisdv1.data.repository.TemporaryGenerationResultRepositoryImpl import com.shifthackz.aisdv1.data.repository.WakeLockRepositoryImpl import com.shifthackz.aisdv1.domain.repository.DownloadableModelRepository +import com.shifthackz.aisdv1.domain.repository.EmbeddingsRepository import com.shifthackz.aisdv1.domain.repository.GenerationResultRepository import com.shifthackz.aisdv1.domain.repository.HordeGenerationRepository import com.shifthackz.aisdv1.domain.repository.HuggingFaceGenerationRepository @@ -37,7 +38,6 @@ import com.shifthackz.aisdv1.domain.repository.ServerConfigurationRepository import com.shifthackz.aisdv1.domain.repository.StabilityAiCreditsRepository import com.shifthackz.aisdv1.domain.repository.StabilityAiEnginesRepository import com.shifthackz.aisdv1.domain.repository.StabilityAiGenerationRepository -import com.shifthackz.aisdv1.domain.repository.StableDiffusionEmbeddingsRepository import com.shifthackz.aisdv1.domain.repository.StableDiffusionGenerationRepository import com.shifthackz.aisdv1.domain.repository.StableDiffusionHyperNetworksRepository import com.shifthackz.aisdv1.domain.repository.StableDiffusionModelsRepository @@ -74,7 +74,7 @@ val repositoryModule = module { factoryOf(::StableDiffusionSamplersRepositoryImpl) bind StableDiffusionSamplersRepository::class factoryOf(::LorasRepositoryImpl) bind LorasRepository::class factoryOf(::StableDiffusionHyperNetworksRepositoryImpl) bind StableDiffusionHyperNetworksRepository::class - factoryOf(::StableDiffusionEmbeddingsRepositoryImpl) bind StableDiffusionEmbeddingsRepository::class + factoryOf(::EmbeddingsRepositoryImpl) bind EmbeddingsRepository::class factoryOf(::ServerConfigurationRepositoryImpl) bind ServerConfigurationRepository::class factoryOf(::GenerationResultRepositoryImpl) bind GenerationResultRepository::class factoryOf(::RandomImageRepositoryImpl) bind RandomImageRepository::class diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/local/StableDiffusionEmbeddingsLocalDataSource.kt b/data/src/main/java/com/shifthackz/aisdv1/data/local/EmbeddingsLocalDataSource.kt similarity index 63% rename from data/src/main/java/com/shifthackz/aisdv1/data/local/StableDiffusionEmbeddingsLocalDataSource.kt rename to data/src/main/java/com/shifthackz/aisdv1/data/local/EmbeddingsLocalDataSource.kt index f86b3d28..66db4d3c 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/local/StableDiffusionEmbeddingsLocalDataSource.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/local/EmbeddingsLocalDataSource.kt @@ -2,20 +2,20 @@ package com.shifthackz.aisdv1.data.local import com.shifthackz.aisdv1.data.mappers.mapDomainToEntity import com.shifthackz.aisdv1.data.mappers.mapEntityToDomain -import com.shifthackz.aisdv1.domain.datasource.StableDiffusionEmbeddingsDataSource -import com.shifthackz.aisdv1.domain.entity.StableDiffusionEmbedding +import com.shifthackz.aisdv1.domain.datasource.EmbeddingsDataSource +import com.shifthackz.aisdv1.domain.entity.Embedding import com.shifthackz.aisdv1.storage.db.cache.dao.StableDiffusionEmbeddingDao import com.shifthackz.aisdv1.storage.db.cache.entity.StableDiffusionEmbeddingEntity -internal class StableDiffusionEmbeddingsLocalDataSource( +internal class EmbeddingsLocalDataSource( private val dao: StableDiffusionEmbeddingDao, -) : StableDiffusionEmbeddingsDataSource.Local { +) : EmbeddingsDataSource.Local { override fun getEmbeddings() = dao .queryAll() .map(List::mapEntityToDomain) - override fun insertEmbeddings(list: List) = dao + override fun insertEmbeddings(list: List) = dao .deleteAll() .andThen(dao.insertList(list.mapDomainToEntity())) } diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/StableDiffusionEmbeddingsMappers.kt b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/StableDiffusionEmbeddingsMappers.kt index 216dc124..8403d436 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/StableDiffusionEmbeddingsMappers.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/StableDiffusionEmbeddingsMappers.kt @@ -1,20 +1,20 @@ package com.shifthackz.aisdv1.data.mappers -import com.shifthackz.aisdv1.domain.entity.StableDiffusionEmbedding +import com.shifthackz.aisdv1.domain.entity.Embedding import com.shifthackz.aisdv1.network.response.SdEmbeddingsResponse import com.shifthackz.aisdv1.storage.db.cache.entity.StableDiffusionEmbeddingEntity //region RAW -> DOMAIN -fun SdEmbeddingsResponse.mapRawToCheckpointDomain(): List = - loaded?.keys?.map(::StableDiffusionEmbedding) ?: emptyList() +fun SdEmbeddingsResponse.mapRawToCheckpointDomain(): List = + loaded?.keys?.map(::Embedding) ?: emptyList() //endregion //region DOMAIN -> ENTITY -fun List.mapDomainToEntity(): List = - map(StableDiffusionEmbedding::mapDomainToEntity) +fun List.mapDomainToEntity(): List = + map(Embedding::mapDomainToEntity) -fun StableDiffusionEmbedding.mapDomainToEntity(): StableDiffusionEmbeddingEntity = with(this) { +fun Embedding.mapDomainToEntity(): StableDiffusionEmbeddingEntity = with(this) { StableDiffusionEmbeddingEntity( id = keyword, keyword = keyword, @@ -23,10 +23,10 @@ fun StableDiffusionEmbedding.mapDomainToEntity(): StableDiffusionEmbeddingEntity //endregion //region ENTITY -> DOMAIN -fun List.mapEntityToDomain(): List = +fun List.mapEntityToDomain(): List = map(StableDiffusionEmbeddingEntity::mapEntityToDomain) -fun StableDiffusionEmbeddingEntity.mapEntityToDomain(): StableDiffusionEmbedding = with(this) { - StableDiffusionEmbedding(keyword) +fun StableDiffusionEmbeddingEntity.mapEntityToDomain(): Embedding = with(this) { + Embedding(keyword) } //endregion diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/SwarmUiModelsMappers.kt b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/SwarmUiModelsMappers.kt index d3e5fd27..931fa4d0 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/SwarmUiModelsMappers.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/SwarmUiModelsMappers.kt @@ -1,5 +1,6 @@ package com.shifthackz.aisdv1.data.mappers +import com.shifthackz.aisdv1.domain.entity.Embedding import com.shifthackz.aisdv1.domain.entity.LoRA import com.shifthackz.aisdv1.domain.entity.SwarmUiModel import com.shifthackz.aisdv1.network.model.SwarmUiModelRaw @@ -38,6 +39,18 @@ fun SwarmUiModelRaw.mapRawToLoraDomain(): LoRA = with(this) { } //endregion +//region RAW -> EMBEDDING DOMAIN +fun SwarmUiModelsResponse.mapRawToEmbeddingDomain(): List = with(this) { + this.files?.mapRawToEmbeddingDomain() ?: emptyList() +} + +fun List.mapRawToEmbeddingDomain(): List = map(SwarmUiModelRaw::mapRawToEmbeddingDomain) + +fun SwarmUiModelRaw.mapRawToEmbeddingDomain(): Embedding = with(this) { + Embedding(title ?: "") +} +//endregion + //region CHECKPOINT DOMAIN --> ENTITY fun List.mapDomainToEntity(): List = map(SwarmUiModel::mapDomainToEntity) diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/remote/StableDiffusionEmbeddingsRemoteDataSource.kt b/data/src/main/java/com/shifthackz/aisdv1/data/remote/StableDiffusionEmbeddingsRemoteDataSource.kt index 4c211af3..ae45d962 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/remote/StableDiffusionEmbeddingsRemoteDataSource.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/remote/StableDiffusionEmbeddingsRemoteDataSource.kt @@ -2,7 +2,7 @@ package com.shifthackz.aisdv1.data.remote import com.shifthackz.aisdv1.data.mappers.mapRawToCheckpointDomain import com.shifthackz.aisdv1.data.provider.ServerUrlProvider -import com.shifthackz.aisdv1.domain.datasource.StableDiffusionEmbeddingsDataSource +import com.shifthackz.aisdv1.domain.datasource.EmbeddingsDataSource import com.shifthackz.aisdv1.network.api.automatic1111.Automatic1111RestApi import com.shifthackz.aisdv1.network.api.automatic1111.Automatic1111RestApi.Companion.PATH_EMBEDDINGS import com.shifthackz.aisdv1.network.response.SdEmbeddingsResponse @@ -10,7 +10,7 @@ import com.shifthackz.aisdv1.network.response.SdEmbeddingsResponse internal class StableDiffusionEmbeddingsRemoteDataSource( private val serverUrlProvider: ServerUrlProvider, private val api: Automatic1111RestApi, -) : StableDiffusionEmbeddingsDataSource.Remote { +) : EmbeddingsDataSource.Remote.Automatic1111 { override fun fetchEmbeddings() = serverUrlProvider(PATH_EMBEDDINGS) .flatMap(api::fetchEmbeddings) diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/remote/SwarmUiEmbeddingsRemoteDataSource.kt b/data/src/main/java/com/shifthackz/aisdv1/data/remote/SwarmUiEmbeddingsRemoteDataSource.kt new file mode 100644 index 00000000..6909ff53 --- /dev/null +++ b/data/src/main/java/com/shifthackz/aisdv1/data/remote/SwarmUiEmbeddingsRemoteDataSource.kt @@ -0,0 +1,29 @@ +package com.shifthackz.aisdv1.data.remote + +import com.shifthackz.aisdv1.data.mappers.mapRawToEmbeddingDomain +import com.shifthackz.aisdv1.data.provider.ServerUrlProvider +import com.shifthackz.aisdv1.domain.datasource.EmbeddingsDataSource +import com.shifthackz.aisdv1.domain.entity.Embedding +import com.shifthackz.aisdv1.network.api.swarmui.SwarmUiApi +import com.shifthackz.aisdv1.network.api.swarmui.SwarmUiApi.Companion.PATH_MODELS +import com.shifthackz.aisdv1.network.request.SwarmUiModelsRequest +import com.shifthackz.aisdv1.network.response.SwarmUiModelsResponse +import io.reactivex.rxjava3.core.Single + +class SwarmUiEmbeddingsRemoteDataSource( + private val serverUrlProvider: ServerUrlProvider, + private val api: SwarmUiApi, +) : EmbeddingsDataSource.Remote.SwarmUi { + + override fun fetchEmbeddings(sessionId: String): Single> = serverUrlProvider(PATH_MODELS) + .flatMap { url -> + val request = SwarmUiModelsRequest( + sessionId = sessionId, + subType = "Embedding", + path = "", + depth = 3, + ) + api.fetchModels(url, request) + } + .map(SwarmUiModelsResponse::mapRawToEmbeddingDomain) +} diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/repository/EmbeddingsRepositoryImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/repository/EmbeddingsRepositoryImpl.kt new file mode 100644 index 00000000..6b7ef62a --- /dev/null +++ b/data/src/main/java/com/shifthackz/aisdv1/data/repository/EmbeddingsRepositoryImpl.kt @@ -0,0 +1,38 @@ +package com.shifthackz.aisdv1.data.repository + +import com.shifthackz.aisdv1.domain.datasource.EmbeddingsDataSource +import com.shifthackz.aisdv1.domain.datasource.SwarmUiSessionDataSource +import com.shifthackz.aisdv1.domain.entity.Embedding +import com.shifthackz.aisdv1.domain.entity.ServerSource +import com.shifthackz.aisdv1.domain.preference.PreferenceManager +import com.shifthackz.aisdv1.domain.repository.EmbeddingsRepository +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single + +internal class EmbeddingsRepositoryImpl( + private val rdsA1111: EmbeddingsDataSource.Remote.Automatic1111, + private val rdsSwarm: EmbeddingsDataSource.Remote.SwarmUi, + private val swarmSession: SwarmUiSessionDataSource, + private val lds: EmbeddingsDataSource.Local, + private val preferenceManager: PreferenceManager, +) : EmbeddingsRepository { + + override fun fetchEmbeddings(): Completable = when (preferenceManager.source) { + ServerSource.AUTOMATIC1111 -> rdsA1111 + .fetchEmbeddings() + .flatMapCompletable(lds::insertEmbeddings) + + ServerSource.SWARM_UI -> swarmSession + .getSessionId() + .flatMap(rdsSwarm::fetchEmbeddings) + .flatMapCompletable(lds::insertEmbeddings) + + else -> Completable.complete() + } + + override fun fetchAndGetEmbeddings(): Single> = fetchEmbeddings() + .onErrorComplete() + .andThen(lds.getEmbeddings()) + + override fun getEmbeddings(): Single> = lds.getEmbeddings() +} diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/repository/StableDiffusionEmbeddingsRepositoryImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/repository/StableDiffusionEmbeddingsRepositoryImpl.kt deleted file mode 100644 index 1f0dc384..00000000 --- a/data/src/main/java/com/shifthackz/aisdv1/data/repository/StableDiffusionEmbeddingsRepositoryImpl.kt +++ /dev/null @@ -1,20 +0,0 @@ -package com.shifthackz.aisdv1.data.repository - -import com.shifthackz.aisdv1.domain.datasource.StableDiffusionEmbeddingsDataSource -import com.shifthackz.aisdv1.domain.repository.StableDiffusionEmbeddingsRepository - -internal class StableDiffusionEmbeddingsRepositoryImpl( - private val remoteDataSource: StableDiffusionEmbeddingsDataSource.Remote, - private val localDataSource: StableDiffusionEmbeddingsDataSource.Local, -) : StableDiffusionEmbeddingsRepository { - - override fun fetchEmbeddings() = remoteDataSource - .fetchEmbeddings() - .flatMapCompletable(localDataSource::insertEmbeddings) - - override fun fetchAndGetEmbeddings() = fetchEmbeddings() - .onErrorComplete() - .andThen(localDataSource.getEmbeddings()) - - override fun getEmbeddings() = localDataSource.getEmbeddings() -} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/local/StableDiffusionEmbeddingsLocalDataSourceTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/local/EmbeddingsLocalDataSourceTest.kt similarity index 87% rename from data/src/test/java/com/shifthackz/aisdv1/data/local/StableDiffusionEmbeddingsLocalDataSourceTest.kt rename to data/src/test/java/com/shifthackz/aisdv1/data/local/EmbeddingsLocalDataSourceTest.kt index 09cc3b27..f5a8c652 100644 --- a/data/src/test/java/com/shifthackz/aisdv1/data/local/StableDiffusionEmbeddingsLocalDataSourceTest.kt +++ b/data/src/test/java/com/shifthackz/aisdv1/data/local/EmbeddingsLocalDataSourceTest.kt @@ -1,7 +1,7 @@ package com.shifthackz.aisdv1.data.local +import com.shifthackz.aisdv1.data.mocks.mockEmbeddings import com.shifthackz.aisdv1.data.mocks.mockStableDiffusionEmbeddingEntities -import com.shifthackz.aisdv1.data.mocks.mockStableDiffusionEmbeddings import com.shifthackz.aisdv1.storage.db.cache.dao.StableDiffusionEmbeddingDao import io.mockk.every import io.mockk.mockk @@ -9,12 +9,12 @@ import io.reactivex.rxjava3.core.Completable import io.reactivex.rxjava3.core.Single import org.junit.Test -class StableDiffusionEmbeddingsLocalDataSourceTest { +class EmbeddingsLocalDataSourceTest { private val stubException = Throwable("Database error.") private val stubDao = mockk() - private val localDataSource = StableDiffusionEmbeddingsLocalDataSource(stubDao) + private val localDataSource = EmbeddingsLocalDataSource(stubDao) @Test fun `given attempt to get embeddings, dao returns list, expected valid domain model list value`() { @@ -26,7 +26,7 @@ class StableDiffusionEmbeddingsLocalDataSourceTest { .getEmbeddings() .test() .assertNoErrors() - .assertValue(mockStableDiffusionEmbeddings) + .assertValue(mockEmbeddings) .await() .assertComplete() } @@ -72,7 +72,7 @@ class StableDiffusionEmbeddingsLocalDataSourceTest { } returns Completable.complete() localDataSource - .insertEmbeddings(mockStableDiffusionEmbeddings) + .insertEmbeddings(mockEmbeddings) .test() .assertNoErrors() .await() @@ -90,7 +90,7 @@ class StableDiffusionEmbeddingsLocalDataSourceTest { } returns Completable.complete() localDataSource - .insertEmbeddings(mockStableDiffusionEmbeddings) + .insertEmbeddings(mockEmbeddings) .test() .assertError(stubException) .await() @@ -108,7 +108,7 @@ class StableDiffusionEmbeddingsLocalDataSourceTest { } returns Completable.error(stubException) localDataSource - .insertEmbeddings(mockStableDiffusionEmbeddings) + .insertEmbeddings(mockEmbeddings) .test() .assertError(stubException) .await() diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StableDiffusionEmbeddingMocks.kt b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StableDiffusionEmbeddingMocks.kt index a7286a4e..c4a5e9d2 100644 --- a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StableDiffusionEmbeddingMocks.kt +++ b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StableDiffusionEmbeddingMocks.kt @@ -1,8 +1,8 @@ package com.shifthackz.aisdv1.data.mocks -import com.shifthackz.aisdv1.domain.entity.StableDiffusionEmbedding +import com.shifthackz.aisdv1.domain.entity.Embedding -val mockStableDiffusionEmbeddings = listOf( - StableDiffusionEmbedding("keyword_5598"), - StableDiffusionEmbedding("keyword_151297"), +val mockEmbeddings = listOf( + Embedding("keyword_5598"), + Embedding("keyword_151297"), ) diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/remote/StableDiffusionEmbeddingsRemoteDataSourceTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/remote/StableDiffusionEmbeddingsRemoteDataSourceTest.kt index bc31e80f..94528711 100644 --- a/data/src/test/java/com/shifthackz/aisdv1/data/remote/StableDiffusionEmbeddingsRemoteDataSourceTest.kt +++ b/data/src/test/java/com/shifthackz/aisdv1/data/remote/StableDiffusionEmbeddingsRemoteDataSourceTest.kt @@ -3,7 +3,7 @@ package com.shifthackz.aisdv1.data.remote import com.shifthackz.aisdv1.data.mocks.mockEmptySdEmbeddingsResponse import com.shifthackz.aisdv1.data.mocks.mockSdEmbeddingsResponse import com.shifthackz.aisdv1.data.provider.ServerUrlProvider -import com.shifthackz.aisdv1.domain.entity.StableDiffusionEmbedding +import com.shifthackz.aisdv1.domain.entity.Embedding import com.shifthackz.aisdv1.network.api.automatic1111.Automatic1111RestApi import io.mockk.every import io.mockk.mockk @@ -39,7 +39,7 @@ class StableDiffusionEmbeddingsRemoteDataSourceTest { .fetchEmbeddings() .test() .assertNoErrors() - .assertValue(listOf(StableDiffusionEmbedding("1504"))) + .assertValue(listOf(Embedding("1504"))) .await() .assertComplete() } diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/repository/StableDiffusionEmbeddingsRepositoryImplTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/repository/EmbeddingsRepositoryImplTest.kt similarity index 81% rename from data/src/test/java/com/shifthackz/aisdv1/data/repository/StableDiffusionEmbeddingsRepositoryImplTest.kt rename to data/src/test/java/com/shifthackz/aisdv1/data/repository/EmbeddingsRepositoryImplTest.kt index 933c19c3..acb46779 100644 --- a/data/src/test/java/com/shifthackz/aisdv1/data/repository/StableDiffusionEmbeddingsRepositoryImplTest.kt +++ b/data/src/test/java/com/shifthackz/aisdv1/data/repository/EmbeddingsRepositoryImplTest.kt @@ -1,29 +1,29 @@ package com.shifthackz.aisdv1.data.repository -import com.shifthackz.aisdv1.data.mocks.mockStableDiffusionEmbeddings -import com.shifthackz.aisdv1.domain.datasource.StableDiffusionEmbeddingsDataSource +import com.shifthackz.aisdv1.data.mocks.mockEmbeddings +import com.shifthackz.aisdv1.domain.datasource.EmbeddingsDataSource import io.mockk.every import io.mockk.mockk import io.reactivex.rxjava3.core.Completable import io.reactivex.rxjava3.core.Single import org.junit.Test -class StableDiffusionEmbeddingsRepositoryImplTest { +class EmbeddingsRepositoryImplTest { private val stubException = Throwable("Something went wrong.") - private val stubRemoteDataSource = mockk() - private val stubLocalDataSource = mockk() + private val stubRemoteDataSource = mockk() + private val stubLocalDataSource = mockk() - private val repository = StableDiffusionEmbeddingsRepositoryImpl( - remoteDataSource = stubRemoteDataSource, - localDataSource = stubLocalDataSource, + private val repository = EmbeddingsRepositoryImpl( + rdsA1111 = stubRemoteDataSource, + lds = stubLocalDataSource, ) @Test fun `given attempt to fetch embeddings, remote returns data, local insert success, expected complete value`() { every { stubRemoteDataSource.fetchEmbeddings() - } returns Single.just(mockStableDiffusionEmbeddings) + } returns Single.just(mockEmbeddings) every { stubLocalDataSource.insertEmbeddings(any()) @@ -59,7 +59,7 @@ class StableDiffusionEmbeddingsRepositoryImplTest { fun `given attempt to fetch embeddings, remote returns data, local insert fails, expected error value`() { every { stubRemoteDataSource.fetchEmbeddings() - } returns Single.just(mockStableDiffusionEmbeddings) + } returns Single.just(mockEmbeddings) every { stubLocalDataSource.insertEmbeddings(any()) @@ -77,13 +77,13 @@ class StableDiffusionEmbeddingsRepositoryImplTest { fun `given attempt to get embeddings, local data source returns list, expected valid domain models list value`() { every { stubLocalDataSource.getEmbeddings() - } returns Single.just(mockStableDiffusionEmbeddings) + } returns Single.just(mockEmbeddings) repository .getEmbeddings() .test() .assertNoErrors() - .assertValue(mockStableDiffusionEmbeddings) + .assertValue(mockEmbeddings) .await() .assertComplete() } @@ -122,7 +122,7 @@ class StableDiffusionEmbeddingsRepositoryImplTest { fun `given attempt to fetch and get embeddings, remote returns data, local returns data, expected valid domain models list value`() { every { stubRemoteDataSource.fetchEmbeddings() - } returns Single.just(mockStableDiffusionEmbeddings) + } returns Single.just(mockEmbeddings) every { stubLocalDataSource.insertEmbeddings(any()) @@ -130,13 +130,13 @@ class StableDiffusionEmbeddingsRepositoryImplTest { every { stubLocalDataSource.getEmbeddings() - } returns Single.just(mockStableDiffusionEmbeddings) + } returns Single.just(mockEmbeddings) repository .fetchAndGetEmbeddings() .test() .assertNoErrors() - .assertValue(mockStableDiffusionEmbeddings) + .assertValue(mockEmbeddings) .await() .assertComplete() } @@ -153,13 +153,13 @@ class StableDiffusionEmbeddingsRepositoryImplTest { every { stubLocalDataSource.getEmbeddings() - } returns Single.just(mockStableDiffusionEmbeddings) + } returns Single.just(mockEmbeddings) repository .fetchAndGetEmbeddings() .test() .assertNoErrors() - .assertValue(mockStableDiffusionEmbeddings) + .assertValue(mockEmbeddings) .await() .assertComplete() } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/EmbeddingsDataSource.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/EmbeddingsDataSource.kt new file mode 100644 index 00000000..cadd6b43 --- /dev/null +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/EmbeddingsDataSource.kt @@ -0,0 +1,24 @@ +package com.shifthackz.aisdv1.domain.datasource + +import com.shifthackz.aisdv1.domain.entity.Embedding +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single + +sealed interface EmbeddingsDataSource { + + interface Remote : EmbeddingsDataSource { + + interface Automatic1111 : Remote { + fun fetchEmbeddings(): Single> + } + + interface SwarmUi : Remote { + fun fetchEmbeddings(sessionId: String): Single> + } + } + + interface Local : EmbeddingsDataSource { + fun getEmbeddings(): Single> + fun insertEmbeddings(list: List): Completable + } +} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/StableDiffusionEmbeddingsDataSource.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/StableDiffusionEmbeddingsDataSource.kt deleted file mode 100644 index a1ed1da8..00000000 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/StableDiffusionEmbeddingsDataSource.kt +++ /dev/null @@ -1,17 +0,0 @@ -package com.shifthackz.aisdv1.domain.datasource - -import com.shifthackz.aisdv1.domain.entity.StableDiffusionEmbedding -import io.reactivex.rxjava3.core.Completable -import io.reactivex.rxjava3.core.Single - -sealed interface StableDiffusionEmbeddingsDataSource { - - interface Remote : StableDiffusionEmbeddingsDataSource { - fun fetchEmbeddings(): Single> - } - - interface Local : StableDiffusionEmbeddingsDataSource { - fun getEmbeddings(): Single> - fun insertEmbeddings(list: List): Completable - } -} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/StableDiffusionEmbedding.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/Embedding.kt similarity index 66% rename from domain/src/main/java/com/shifthackz/aisdv1/domain/entity/StableDiffusionEmbedding.kt rename to domain/src/main/java/com/shifthackz/aisdv1/domain/entity/Embedding.kt index a4faccc8..6e2d1b90 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/StableDiffusionEmbedding.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/Embedding.kt @@ -1,5 +1,5 @@ package com.shifthackz.aisdv1.domain.entity -data class StableDiffusionEmbedding( +data class Embedding( val keyword: String, ) diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/EmbeddingsRepository.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/EmbeddingsRepository.kt new file mode 100644 index 00000000..e641225e --- /dev/null +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/EmbeddingsRepository.kt @@ -0,0 +1,11 @@ +package com.shifthackz.aisdv1.domain.repository + +import com.shifthackz.aisdv1.domain.entity.Embedding +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single + +interface EmbeddingsRepository { + fun fetchEmbeddings(): Completable + fun fetchAndGetEmbeddings(): Single> + fun getEmbeddings(): Single> +} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/StableDiffusionEmbeddingsRepository.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/StableDiffusionEmbeddingsRepository.kt deleted file mode 100644 index 57118629..00000000 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/StableDiffusionEmbeddingsRepository.kt +++ /dev/null @@ -1,11 +0,0 @@ -package com.shifthackz.aisdv1.domain.repository - -import com.shifthackz.aisdv1.domain.entity.StableDiffusionEmbedding -import io.reactivex.rxjava3.core.Completable -import io.reactivex.rxjava3.core.Single - -interface StableDiffusionEmbeddingsRepository { - fun fetchEmbeddings(): Completable - fun fetchAndGetEmbeddings(): Single> - fun getEmbeddings(): Single> -} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/caching/DataPreLoaderUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/caching/DataPreLoaderUseCaseImpl.kt index a711718c..84a5ffe8 100755 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/caching/DataPreLoaderUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/caching/DataPreLoaderUseCaseImpl.kt @@ -1,8 +1,8 @@ package com.shifthackz.aisdv1.domain.usecase.caching +import com.shifthackz.aisdv1.domain.repository.EmbeddingsRepository import com.shifthackz.aisdv1.domain.repository.LorasRepository import com.shifthackz.aisdv1.domain.repository.ServerConfigurationRepository -import com.shifthackz.aisdv1.domain.repository.StableDiffusionEmbeddingsRepository import com.shifthackz.aisdv1.domain.repository.StableDiffusionHyperNetworksRepository import com.shifthackz.aisdv1.domain.repository.StableDiffusionModelsRepository import com.shifthackz.aisdv1.domain.repository.StableDiffusionSamplersRepository @@ -13,7 +13,7 @@ internal class DataPreLoaderUseCaseImpl( private val sdSamplersRepository: StableDiffusionSamplersRepository, private val sdLorasRepository: LorasRepository, private val sdHyperNetworksRepository: StableDiffusionHyperNetworksRepository, - private val sdEmbeddingsRepository: StableDiffusionEmbeddingsRepository, + private val sdEmbeddingsRepository: EmbeddingsRepository, ) : DataPreLoaderUseCase { override operator fun invoke() = serverConfigurationRepository diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/sdembedding/FetchAndGetEmbeddingsUseCase.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/sdembedding/FetchAndGetEmbeddingsUseCase.kt index 86f148ad..ca731928 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/sdembedding/FetchAndGetEmbeddingsUseCase.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/sdembedding/FetchAndGetEmbeddingsUseCase.kt @@ -1,8 +1,8 @@ package com.shifthackz.aisdv1.domain.usecase.sdembedding -import com.shifthackz.aisdv1.domain.entity.StableDiffusionEmbedding +import com.shifthackz.aisdv1.domain.entity.Embedding import io.reactivex.rxjava3.core.Single interface FetchAndGetEmbeddingsUseCase { - operator fun invoke(): Single> + operator fun invoke(): Single> } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/sdembedding/FetchAndGetEmbeddingsUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/sdembedding/FetchAndGetEmbeddingsUseCaseImpl.kt index 1f6d5c22..257e37e0 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/sdembedding/FetchAndGetEmbeddingsUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/sdembedding/FetchAndGetEmbeddingsUseCaseImpl.kt @@ -1,10 +1,12 @@ package com.shifthackz.aisdv1.domain.usecase.sdembedding -import com.shifthackz.aisdv1.domain.repository.StableDiffusionEmbeddingsRepository +import com.shifthackz.aisdv1.domain.entity.Embedding +import com.shifthackz.aisdv1.domain.repository.EmbeddingsRepository +import io.reactivex.rxjava3.core.Single internal class FetchAndGetEmbeddingsUseCaseImpl( - private val repository: StableDiffusionEmbeddingsRepository, + private val repository: EmbeddingsRepository, ) : FetchAndGetEmbeddingsUseCase { - override fun invoke() = repository.fetchAndGetEmbeddings() + override fun invoke(): Single> = repository.fetchAndGetEmbeddings() } diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/StableDiffusionEmbeddingMocks.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/StableDiffusionEmbeddingMocks.kt index 8cf54d7e..3a44c632 100644 --- a/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/StableDiffusionEmbeddingMocks.kt +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/StableDiffusionEmbeddingMocks.kt @@ -1,9 +1,9 @@ package com.shifthackz.aisdv1.domain.mocks -import com.shifthackz.aisdv1.domain.entity.StableDiffusionEmbedding +import com.shifthackz.aisdv1.domain.entity.Embedding -val mockStableDiffusionEmbeddings = listOf( - StableDiffusionEmbedding("embedding_1"), - StableDiffusionEmbedding("embedding_2"), - StableDiffusionEmbedding("embedding_3"), +val mockEmbeddings = listOf( + Embedding("embedding_1"), + Embedding("embedding_2"), + Embedding("embedding_3"), ) diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/caching/DataPreLoaderUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/caching/DataPreLoaderUseCaseImplTest.kt index 7b422be1..9a920900 100644 --- a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/caching/DataPreLoaderUseCaseImplTest.kt +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/caching/DataPreLoaderUseCaseImplTest.kt @@ -2,9 +2,9 @@ package com.shifthackz.aisdv1.domain.usecase.caching import com.nhaarman.mockitokotlin2.mock import com.nhaarman.mockitokotlin2.whenever +import com.shifthackz.aisdv1.domain.repository.EmbeddingsRepository import com.shifthackz.aisdv1.domain.repository.LorasRepository import com.shifthackz.aisdv1.domain.repository.ServerConfigurationRepository -import com.shifthackz.aisdv1.domain.repository.StableDiffusionEmbeddingsRepository import com.shifthackz.aisdv1.domain.repository.StableDiffusionHyperNetworksRepository import com.shifthackz.aisdv1.domain.repository.StableDiffusionModelsRepository import com.shifthackz.aisdv1.domain.repository.StableDiffusionSamplersRepository @@ -18,7 +18,7 @@ class DataPreLoaderUseCaseImplTest { private val stubStableDiffusionSamplersRepository = mock() private val stubLorasRepository = mock() private val stubStableDiffusionHyperNetworksRepository = mock() - private val stubStableDiffusionEmbeddingsRepository = mock() + private val stubEmbeddingsRepository = mock() private val useCase = DataPreLoaderUseCaseImpl( serverConfigurationRepository = stubServerConfigurationRepository, @@ -26,7 +26,7 @@ class DataPreLoaderUseCaseImplTest { sdSamplersRepository = stubStableDiffusionSamplersRepository, sdLorasRepository = stubLorasRepository, sdHyperNetworksRepository = stubStableDiffusionHyperNetworksRepository, - sdEmbeddingsRepository = stubStableDiffusionEmbeddingsRepository, + sdEmbeddingsRepository = stubEmbeddingsRepository, ) @Test @@ -46,7 +46,7 @@ class DataPreLoaderUseCaseImplTest { whenever(stubStableDiffusionHyperNetworksRepository.fetchHyperNetworks()) .thenReturn(Completable.complete()) - whenever(stubStableDiffusionEmbeddingsRepository.fetchEmbeddings()) + whenever(stubEmbeddingsRepository.fetchEmbeddings()) .thenReturn(Completable.complete()) useCase() @@ -75,7 +75,7 @@ class DataPreLoaderUseCaseImplTest { whenever(stubStableDiffusionHyperNetworksRepository.fetchHyperNetworks()) .thenReturn(Completable.complete()) - whenever(stubStableDiffusionEmbeddingsRepository.fetchEmbeddings()) + whenever(stubEmbeddingsRepository.fetchEmbeddings()) .thenReturn(Completable.complete()) useCase() @@ -104,7 +104,7 @@ class DataPreLoaderUseCaseImplTest { whenever(stubStableDiffusionHyperNetworksRepository.fetchHyperNetworks()) .thenReturn(Completable.complete()) - whenever(stubStableDiffusionEmbeddingsRepository.fetchEmbeddings()) + whenever(stubEmbeddingsRepository.fetchEmbeddings()) .thenReturn(Completable.complete()) useCase() @@ -133,7 +133,7 @@ class DataPreLoaderUseCaseImplTest { whenever(stubStableDiffusionHyperNetworksRepository.fetchHyperNetworks()) .thenReturn(Completable.complete()) - whenever(stubStableDiffusionEmbeddingsRepository.fetchEmbeddings()) + whenever(stubEmbeddingsRepository.fetchEmbeddings()) .thenReturn(Completable.complete()) useCase() @@ -162,7 +162,7 @@ class DataPreLoaderUseCaseImplTest { whenever(stubStableDiffusionHyperNetworksRepository.fetchHyperNetworks()) .thenReturn(Completable.complete()) - whenever(stubStableDiffusionEmbeddingsRepository.fetchEmbeddings()) + whenever(stubEmbeddingsRepository.fetchEmbeddings()) .thenReturn(Completable.complete()) useCase() @@ -191,7 +191,7 @@ class DataPreLoaderUseCaseImplTest { whenever(stubStableDiffusionHyperNetworksRepository.fetchHyperNetworks()) .thenReturn(Completable.error(stubException)) - whenever(stubStableDiffusionEmbeddingsRepository.fetchEmbeddings()) + whenever(stubEmbeddingsRepository.fetchEmbeddings()) .thenReturn(Completable.complete()) useCase() @@ -220,7 +220,7 @@ class DataPreLoaderUseCaseImplTest { whenever(stubStableDiffusionHyperNetworksRepository.fetchHyperNetworks()) .thenReturn(Completable.complete()) - whenever(stubStableDiffusionEmbeddingsRepository.fetchEmbeddings()) + whenever(stubEmbeddingsRepository.fetchEmbeddings()) .thenReturn(Completable.error(stubException)) useCase() diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/sdembedding/FetchAndGetEmbeddingsUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/sdembedding/FetchAndGetEmbeddingsUseCaseImplTest.kt index 28952d80..bc806e01 100644 --- a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/sdembedding/FetchAndGetEmbeddingsUseCaseImplTest.kt +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/sdembedding/FetchAndGetEmbeddingsUseCaseImplTest.kt @@ -3,26 +3,26 @@ package com.shifthackz.aisdv1.domain.usecase.sdembedding import com.nhaarman.mockitokotlin2.doReturn import com.nhaarman.mockitokotlin2.mock import com.nhaarman.mockitokotlin2.whenever -import com.shifthackz.aisdv1.domain.mocks.mockStableDiffusionEmbeddings -import com.shifthackz.aisdv1.domain.repository.StableDiffusionEmbeddingsRepository +import com.shifthackz.aisdv1.domain.mocks.mockEmbeddings +import com.shifthackz.aisdv1.domain.repository.EmbeddingsRepository import io.reactivex.rxjava3.core.Single import org.junit.Test class FetchAndGetEmbeddingsUseCaseImplTest { - private val stubRepository = mock() + private val stubRepository = mock() private val useCase = FetchAndGetEmbeddingsUseCaseImpl(stubRepository) @Test fun `given repository provided embeddings list, expected valid list value`() { whenever(stubRepository.fetchAndGetEmbeddings()) - .doReturn(Single.just(mockStableDiffusionEmbeddings)) + .doReturn(Single.just(mockEmbeddings)) useCase() .test() .assertNoErrors() - .assertValue(mockStableDiffusionEmbeddings) + .assertValue(mockEmbeddings) .await() .assertComplete() } diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/StableDiffusionEmbeddingMocks.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/StableDiffusionEmbeddingMocks.kt index 5c83e46e..5b8f9b45 100644 --- a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/StableDiffusionEmbeddingMocks.kt +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/StableDiffusionEmbeddingMocks.kt @@ -1,8 +1,8 @@ package com.shifthackz.aisdv1.presentation.mocks -import com.shifthackz.aisdv1.domain.entity.StableDiffusionEmbedding +import com.shifthackz.aisdv1.domain.entity.Embedding -val mockStableDiffusionEmbeddings = listOf( - StableDiffusionEmbedding("5598"), - StableDiffusionEmbedding("151297"), +val mockEmbeddings = listOf( + Embedding("5598"), + Embedding("151297"), ) diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/modal/embedding/EmbeddingViewModelTest.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/modal/embedding/EmbeddingViewModelTest.kt index 0d5535c8..78d27d5b 100644 --- a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/modal/embedding/EmbeddingViewModelTest.kt +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/modal/embedding/EmbeddingViewModelTest.kt @@ -2,7 +2,7 @@ package com.shifthackz.aisdv1.presentation.modal.embedding import com.shifthackz.aisdv1.domain.usecase.sdembedding.FetchAndGetEmbeddingsUseCase import com.shifthackz.aisdv1.presentation.core.CoreViewModelTest -import com.shifthackz.aisdv1.presentation.mocks.mockStableDiffusionEmbeddings +import com.shifthackz.aisdv1.presentation.mocks.mockEmbeddings import com.shifthackz.aisdv1.presentation.modal.extras.ExtrasEffect import com.shifthackz.aisdv1.presentation.model.ErrorState import com.shifthackz.aisdv1.presentation.stub.stubSchedulersProvider @@ -28,7 +28,7 @@ class EmbeddingViewModelTest : CoreViewModelTest() { fun `given update data, fetch embeddings successful, expected UI state with embeddings list`() { every { stubFetchAndGetEmbeddingsUseCase() - } returns Single.just(mockStableDiffusionEmbeddings) + } returns Single.just(mockEmbeddings) viewModel.updateData("prompt", "negative") @@ -106,7 +106,7 @@ class EmbeddingViewModelTest : CoreViewModelTest() { fun `given received ToggleItem intent, expected item from intent isInNegativePrompt changed in UI state`() { every { stubFetchAndGetEmbeddingsUseCase() - } returns Single.just(mockStableDiffusionEmbeddings) + } returns Single.just(mockEmbeddings) viewModel.updateData("prompt", "negative") From e4f09ab46df38bd9151545c7f25143bb46cb41f2 Mon Sep 17 00:00:00 2001 From: ShiftHackZ Date: Sun, 4 Aug 2024 23:23:03 +0300 Subject: [PATCH 05/11] Fixed GenerationBottomToolbar --- .../widget/toolbar/GenearionBottomToolbar.kt | 46 ++++++++++--------- 1 file changed, 24 insertions(+), 22 deletions(-) diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/toolbar/GenearionBottomToolbar.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/toolbar/GenearionBottomToolbar.kt index 9ee88d63..d46a6ace 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/toolbar/GenearionBottomToolbar.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/toolbar/GenearionBottomToolbar.kt @@ -163,28 +163,30 @@ private fun GenerationBottomToolbarBottomLayer( color = localColor, style = localStyle, ) - Spacer( - modifier = Modifier - .width(1.dp) - .height(with(LocalDensity.current) { dividerHeight.toDp() }) - .background(color = accentColor), - ) - Text( - modifier = localModifier { - processIntent( - GenerationMviIntent.SetModal( - Modal.ExtraBottomSheet( - state.prompt, - state.negativePrompt, - ExtraType.HyperNet, + if (state.mode == ServerSource.AUTOMATIC1111) { + Spacer( + modifier = Modifier + .width(1.dp) + .height(with(LocalDensity.current) { dividerHeight.toDp() }) + .background(color = accentColor), + ) + Text( + modifier = localModifier { + processIntent( + GenerationMviIntent.SetModal( + Modal.ExtraBottomSheet( + state.prompt, + state.negativePrompt, + ExtraType.HyperNet, + ), ), - ), - ) - }, - text = stringResource(id = R.string.title_hyper_net_short), - textAlign = TextAlign.Center, - color = localColor, - style = localStyle, - ) + ) + }, + text = stringResource(id = R.string.title_hyper_net_short), + textAlign = TextAlign.Center, + color = localColor, + style = localStyle, + ) + } } } From eed94d8c6700411a98cc2a67bc15520b54623ae1 Mon Sep 17 00:00:00 2001 From: ShiftHackZ Date: Mon, 5 Aug 2024 10:19:31 +0300 Subject: [PATCH 06/11] Update strings / docs --- README.md | 19 +++++++++++----- app/build.gradle | 1 + .../aisdv1/app/di/ProvidersModule.kt | 1 + .../aisdv1/core/common/links/LinksProvider.kt | 1 + .../aisdv1/domain/entity/ServerSource.kt | 7 +++--- .../modal/embedding/EmbeddingViewModel.kt | 2 -- .../screen/setup/ServerSetupIntent.kt | 4 ++++ .../components/ConfigurationModeButton.kt | 3 ++- .../screen/setup/forms/Automatic1111Form.kt | 9 ++++---- .../screen/setup/forms/SwarmUiForm.kt | 22 +++++++++++++++++++ presentation/src/main/res/values/strings.xml | 8 ++++--- 11 files changed, 58 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index 8c6bce68..5350358a 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,7 @@ Stable Diffusion AI is an easy-to-use app that lets you quickly generate images - Can use server environment powered by [AI Horde](https://stablehorde.net/) (a crowdsourced distributed cluster of Stable Diffusion workers) - Can use server environment powered by [Stable-Diffusion-WebUI](https://github.com/AUTOMATIC1111/stable-diffusion-webui) (AUTOMATIC1111) +- Can use server environment powered by [SwarmUI](https://github.com/mcmonkeyprojects/SwarmUI) - Can use server envitonment powered by [Hugging Face Inference API](https://huggingface.co/docs/api-inference/quicktour). - Can use server environment powered by [OpenAI](https://platform.openai.com/docs/api-reference/images) (DALL-E-2, DALL-E-3). - Can use server environment powered by [Stability AI](https://platform.stability.ai/). @@ -70,31 +71,39 @@ You can have it running either on your own hardware with modern GPU from Nvidia If for some reason you have no ability to run your server instance, you can toggle the **Demo mode** switch on server setup page: it will allow you to test the app and get familiar with it, but it will return some mock images instead of AI-generated ones. -### Option 2: Use AI Horde +### Option 2: Use your own SwarmUI instance + +This requires you to have the SwarmUI that is running in server mode. + +You can have it running either on your own hardware with modern GPU from Nvidia or AMD, or running it using Google Colab. + +Please refer to the [SwarmUI documentation](https://github.com/mcmonkeyprojects/SwarmUI?tab=readme-ov-file#swarmui) for installation instructions. + +### Option 3: Use AI Horde [AI Horde](https://stablehorde.net/) is a crowdsourced distributed cluster of Image generation workers and text generation workers. AI Horde requires to use API KEY, this mobile app alows to use either default API KEY (which is "0000000000"), or type your own. You can sign up and get your own AI Horde API KEY [here](https://stablehorde.net/register). -### Option 3: Hugging Face Inference +### Option 4: Hugging Face Inference [Hugging Face Inference API](https://huggingface.co/docs/api-inference/index) allows to test and evaluate, over 150,000 publicly accessible machine learning models, or your own private models, via simple HTTP requests, with fast inference hosted on Hugging Face shared infrastructure. This service is free, but is rate-limited. Hugging Face Inference requires to use API KEY, which can be created in [Hugging Face account settings](https://huggingface.co/settings/tokens). -### Option 4: OpenAI +### Option 5: OpenAI OpenAI provides a service for text to image generation using [DALLE-2](https://openai.com/dall-e-2) or [DALLE-3](https://openai.com/dall-e-3) models. This service is paid. OpenAI requires to use API KEY, which can be created in [OpenAI API Key settings](https://platform.openai.com/api-keys). -### Option 5: StabilityAI +### Option 6: StabilityAI [StabilityAI](https://platform.stability.ai/) is the image generation service provided by DreamStudio. StabilityAI requires to use API KEY, which can be created in [API Keys page](https://platform.stability.ai/account/keys). -### Option 6: Local Diffusion (Beta) +### Option 7: Local Diffusion (Beta) Only **txt2img** mode is supported. diff --git a/app/build.gradle b/app/build.gradle index bc3f0a73..8e0b2216 100755 --- a/app/build.gradle +++ b/app/build.gradle @@ -34,6 +34,7 @@ android { buildConfigField "String", "DONATE_URL", "\"https://www.buymeacoffee.com/shifthackz\"" buildConfigField "String", "GITHUB_SOURCE_URL", "\"https://github.com/ShiftHackZ/Stable-Diffusion-Android\"" buildConfigField "String", "SETUP_INSTRUCTIONS_URL", "\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki\"" + buildConfigField "String", "SWARM_UI_INFO_URL", "\"https://github.com/mcmonkeyprojects/SwarmUI/tree/master/docs\"" resourceConfigurations = ["en", "ru", "uk", "tr", "zh"] } diff --git a/app/src/main/java/com/shifthackz/aisdv1/app/di/ProvidersModule.kt b/app/src/main/java/com/shifthackz/aisdv1/app/di/ProvidersModule.kt index 933011dc..6e43a084 100755 --- a/app/src/main/java/com/shifthackz/aisdv1/app/di/ProvidersModule.kt +++ b/app/src/main/java/com/shifthackz/aisdv1/app/di/ProvidersModule.kt @@ -107,6 +107,7 @@ val providersModule = module { override val donateUrl: String = BuildConfig.DONATE_URL override val gitHubSourceUrl: String = BuildConfig.GITHUB_SOURCE_URL override val setupInstructionsUrl: String = BuildConfig.SETUP_INSTRUCTIONS_URL + override val swarmUiInfoUrl: String = BuildConfig.SWARM_UI_INFO_URL override val demoModeUrl: String = BuildConfig.DEMO_MODE_API_URL } } diff --git a/core/common/src/main/java/com/shifthackz/aisdv1/core/common/links/LinksProvider.kt b/core/common/src/main/java/com/shifthackz/aisdv1/core/common/links/LinksProvider.kt index bcf5aed0..ac6670be 100644 --- a/core/common/src/main/java/com/shifthackz/aisdv1/core/common/links/LinksProvider.kt +++ b/core/common/src/main/java/com/shifthackz/aisdv1/core/common/links/LinksProvider.kt @@ -10,5 +10,6 @@ interface LinksProvider { val donateUrl: String val gitHubSourceUrl: String val setupInstructionsUrl: String + val swarmUiInfoUrl: String val demoModeUrl: String } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/ServerSource.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/ServerSource.kt index dae3ebc7..4eeca344 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/ServerSource.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/ServerSource.kt @@ -22,11 +22,10 @@ enum class ServerSource( featureTags = setOf( FeatureTag.Txt2Img, FeatureTag.OwnServer, -// FeatureTag.Img2Img, + FeatureTag.Img2Img, FeatureTag.MultipleModels, -// FeatureTag.Lora, -// FeatureTag.TextualInversion, -// FeatureTag.HyperNetworks, + FeatureTag.Lora, + FeatureTag.TextualInversion, FeatureTag.Batch, ), ), diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/embedding/EmbeddingViewModel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/embedding/EmbeddingViewModel.kt index 2b284aee..4bc3eec7 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/embedding/EmbeddingViewModel.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/embedding/EmbeddingViewModel.kt @@ -1,6 +1,5 @@ package com.shifthackz.aisdv1.presentation.modal.embedding -import com.shifthackz.aisdv1.core.common.log.debugLog import com.shifthackz.aisdv1.core.common.log.errorLog import com.shifthackz.aisdv1.core.common.schedulers.SchedulersProvider import com.shifthackz.aisdv1.core.common.schedulers.subscribeOnMainThread @@ -61,7 +60,6 @@ class EmbeddingViewModel( updateState { it.copy(loading = false, error = ErrorState.Generic) } }, onSuccess = { embeddings -> - debugLog(embeddings) updateState { state -> state.copy( loading = false, diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupIntent.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupIntent.kt index e64c1980..80d51618 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupIntent.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupIntent.kt @@ -58,7 +58,11 @@ sealed interface ServerSetupIntent : MviIntent { data object A1111Instructions : LaunchUrl() { override val url: String get() = linksProvider.setupInstructionsUrl + } + data object SwarmUiInstructions : LaunchUrl() { + override val url: String + get() = linksProvider.swarmUiInfoUrl } data object HordeInfo : LaunchUrl() { diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/components/ConfigurationModeButton.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/components/ConfigurationModeButton.kt index 0ace2923..5643baec 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/components/ConfigurationModeButton.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/components/ConfigurationModeButton.kt @@ -94,13 +94,14 @@ fun ConfigurationModeButton( ) } val descriptionId = when (mode) { - ServerSource.AUTOMATIC1111 -> null + ServerSource.AUTOMATIC1111 -> R.string.hint_server_setup_sub_title ServerSource.HORDE -> R.string.hint_server_horde_sub_title ServerSource.HUGGING_FACE -> R.string.hint_hugging_face_sub_title ServerSource.OPEN_AI -> R.string.hint_open_ai_sub_title ServerSource.LOCAL -> R.string.hint_local_diffusion_sub_title ServerSource.STABILITY_AI -> R.string.hint_stability_ai_sub_title ServerSource.SWARM_UI -> R.string.hint_swarm_ui_sub_title + else -> null } descriptionId?.let { resId -> Text( diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/Automatic1111Form.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/Automatic1111Form.kt index 8211b076..1a60852b 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/Automatic1111Form.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/Automatic1111Form.kt @@ -98,10 +98,11 @@ fun Automatic1111Form( ) Text( modifier = Modifier.padding(top = 8.dp, bottom = 16.dp), - text = stringResource( - if (state.demoMode) R.string.hint_demo_mode - else R.string.hint_valid_urls, - ), + text = if (state.demoMode) { + stringResource(R.string.hint_demo_mode) + } else { + stringResource(R.string.hint_valid_urls, "7860") + }, style = MaterialTheme.typography.bodyMedium, ) } diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/SwarmUiForm.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/SwarmUiForm.kt index e10fe353..a70cf7b5 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/SwarmUiForm.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/SwarmUiForm.kt @@ -3,6 +3,8 @@ package com.shifthackz.aisdv1.presentation.screen.setup.forms import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.fillMaxWidth import androidx.compose.foundation.layout.padding +import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.automirrored.filled.Help import androidx.compose.material3.MaterialTheme import androidx.compose.material3.Text import androidx.compose.material3.TextField @@ -13,9 +15,11 @@ import androidx.compose.ui.text.font.FontWeight import androidx.compose.ui.text.style.TextAlign import androidx.compose.ui.unit.dp import com.shifthackz.aisdv1.core.model.asString +import com.shifthackz.aisdv1.core.model.asUiText import com.shifthackz.aisdv1.presentation.R import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupIntent import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupState +import com.shifthackz.aisdv1.presentation.widget.item.SettingsItem @Composable fun SwarmUiForm( @@ -56,5 +60,23 @@ fun SwarmUiForm( processIntent = processIntent, modifier = fieldModifier, ) + SettingsItem( + modifier = Modifier + .padding(top = 16.dp) + .fillMaxWidth(), + startIcon = Icons.AutoMirrored.Filled.Help, + text = R.string.settings_item_instructions.asUiText(), + onClick = { processIntent(ServerSetupIntent.LaunchUrl.SwarmUiInstructions) }, + ) + Text( + modifier = Modifier.padding(top = 8.dp), + text = stringResource(id = R.string.hint_args_swarm_ui_warning), + style = MaterialTheme.typography.bodyMedium, + ) + Text( + modifier = Modifier.padding(top = 8.dp, bottom = 16.dp), + text = stringResource(R.string.hint_valid_urls, "7801"), + style = MaterialTheme.typography.bodyMedium, + ) } } diff --git a/presentation/src/main/res/values/strings.xml b/presentation/src/main/res/values/strings.xml index c4a9210d..3c4d3fe4 100755 --- a/presentation/src/main/res/values/strings.xml +++ b/presentation/src/main/res/values/strings.xml @@ -112,9 +112,11 @@ HTTP Basic Provide your Stable Diffusion WebUI URL - Here are the examples of server URLs:\n• http://192.168.0.2:7860\n• http://yourdomain.com:7860\n• https://yourdomain.com + A web interface for Stable Diffusion, implemented using Gradio library. + Here are the examples of server URLs:\n• http://192.168.0.2:%1$s\n• http://yourdomain.com:%1$s\n• https://yourdomain.com This mode allows you to test the application behavior, even if you don\'t have Stable Diffusion WebUI server.\n\nIn demo mode app ignores user prompt, does not use AI server, and returns some mock images. Before connecting ensure that:\n• you are running AUTOMATIC1111 WebUI with flags --api --listen\n• your firewall is not blocking 7860 port\n• phone is on the same WiFi with your PC + Before connecting ensure that:\n• your firewall is not blocking 7801 port\n• phone is on the same WiFi with your PC Connect to Horde AI cloud Horde AI is a crowdsourced distributed cluster of Image generation workers and text generation workers. @@ -138,8 +140,8 @@ About Stability AI Stability AI Engine - Connect to Swarm UI - Swarm UI ... + Provide your Swarm UI URL + A Modular Stable Diffusion Web-User-Interface, with an emphasis on making tools easily accessible, high performance, and extensibility. Local Diffusion This configuration allows to run Stable Diffusion AI generations on your phone, with no need to connect to remote server/cloud. From f05af9b96bbe789d678476ba6cda478d907f98a6 Mon Sep 17 00:00:00 2001 From: ShiftHackZ Date: Mon, 5 Aug 2024 10:33:47 +0300 Subject: [PATCH 07/11] Add translations --- presentation/src/main/res/values-ru/strings.xml | 6 ++++++ presentation/src/main/res/values-tr/strings.xml | 10 ++++++++-- presentation/src/main/res/values-uk/strings.xml | 6 ++++++ presentation/src/main/res/values-zh/strings.xml | 7 +++++++ 4 files changed, 27 insertions(+), 2 deletions(-) diff --git a/presentation/src/main/res/values-ru/strings.xml b/presentation/src/main/res/values-ru/strings.xml index 17830fbf..d1b243c1 100644 --- a/presentation/src/main/res/values-ru/strings.xml +++ b/presentation/src/main/res/values-ru/strings.xml @@ -62,6 +62,7 @@ Пакетная генерация Поддержка моделей Генерация без интернета + Собственный сервер CFG Шкала: %1$s Метод выборки Стиль @@ -94,9 +95,11 @@ HTTP Базовая Укажите свой URL-адрес Stable Diffusion WebUI + Веб-интерфейс для Stable Diffusion, реализованный с использованием библиотеки Gradio. Примеры URL-адресов сервера:\nhttp://192.168.0.2:7860\nhttp://yourdomain.com:7860\nhttps://yourdomain.com Этот режим позволяет проверить поведение программы, даже если у вас нет сервера Stable Diffusion WebUI.\n\nВ демонстрационном режиме программа игнорирует параметры генерации, не использует сервер искусственного интеллекта и возвращает фиктивные изображения. Перед подключением убедитесь, что:\n• вы используете AUTOMATIC1111 WebUI с аргументами --api --listen\n• ваш брандмауэр не блокирует порт 7860\n• телефон подключен к одной сети Wi-Fi с ПК + Перед подключением убедитесь, что:\n• ваш брандмауэр не блокирует порт 7860\n• телефон подключен к одной сети Wi-Fi с ПК Подключиться к Horde AI Horde AI – это краудсорсинговый распределенный кластер нод генерации изображений и текста. @@ -120,6 +123,9 @@ О Stability AI Stability AI движок + Укажите свйой URL-адрес Swarm UI + Модульный веб-интерфейс Stable Diffusion, в котором особое внимание уделяется обеспечению легкого доступа к инструментам, высокой производительности и расширяемости. + Эта конфигурация позволяет запускать генерации Stable Diffusion на вашем телефоне без необходимости подключаться к удаленному серверу/облаку. ВНИМАНИЕ! Функциональность Local Diffusion в бета-тестировании. Не ожидайте высококачественных изображений в локальном режиме. \n\nЭта реализация может не работать должным образом на мобильных телефонах. Производительность и скорость генерации зависят от ресурсов вашего телефона (ЦП, ОЗУ) и размера сгенерированного изображения (чем меньше размер изображения, тем быстрее генерируется). diff --git a/presentation/src/main/res/values-tr/strings.xml b/presentation/src/main/res/values-tr/strings.xml index 9ebc6821..86cf57bf 100644 --- a/presentation/src/main/res/values-tr/strings.xml +++ b/presentation/src/main/res/values-tr/strings.xml @@ -62,6 +62,7 @@ Toplu üretim Çoklu Modeller Çevrimdışı nesil + Kendi Sunucumuz CFG Scale: %1$s Örneklerme Yöntemi Stil ön ayarı @@ -94,9 +95,11 @@ HTTP Temel Lütfen Stable Diffusion WebUI(AUTOMATIC1111) URL adresinizi yazın. + Gradio kütüphanesi kullanılarak uygulanan Stable Diffusion için bir web arayüzü. Bazı sunucu örnekleri:\n• http://192.168.0.2:7860\n• http://alanadiniz.com:7860\n• https://alanadiniz.com - This mode allows you to test the application behavior, even if you don\'t have Stable Diffusion WebUI server.\n\nIn demo mode app ignores user prompt, does not use AI server, and returns some mock images. - Before connecting ensure that:\n• you are running AUTOMATIC1111 WebUI with flags --api --listen\n• your firewall is not blocking 7860 port\n• phone is on the same WiFi with your PC + Bu mod, Stable Diffusion WebUI sunucunuz olmasa bile uygulama davranışını test etmenize olanak tanır.\n\nDemo modunda uygulama kullanıcı istemini görmezden gelir, AI sunucusunu kullanmaz ve bazı sahte görüntüler döndürür. + Bağlanmadan önce şunlardan emin olun:\n• AUTOMATIC1111 WebUI\'yi --api --listen bayraklarıyla çalıştırıyorsunuz\n• Güvenlik duvarınız 7860 portunu engellemiyor\n• Telefon, bilgisayarınızla aynı WiFi'da. + Bağlanmadan önce şunlardan emin olun:\n• Güvenlik duvarınız 7860 portunu engellemiyor\n• Telefon, bilgisayarınızla aynı WiFi'da. Horde AI bulutuna bağlanın Horde AI, Görüntü oluşturma çalışanları ve metin oluşturma çalışanlarından oluşan kitle kaynaklı dağıtılmış bir kümedir. @@ -120,6 +123,9 @@ Hakkında Stability AI Stability AI motoru + Swarm UI URL\'nizi sağlayın + Araçlara kolay erişim, yüksek performans ve genişletilebilirlik üzerine odaklanan Modüler, Kararlı Yaygın Web Kullanıcı Arayüzü. + Bu yapılandırma, telefonunuzda uzak sunucuya/buluta bağlanmaya gerek kalmadan Stable Diffusion AI nesillerini çalıştırmanıza izin verir. Uyarı! Yerel Yayılma işlevi beta testindedir. Yerel modu kullanarak yüksek kaliteli görüntüler beklemeyin. \n\nBu uygulama, güçlü olmayan telefonlarda iyi çalışmayabilir. Oluşturma performansı ve hızı, telefonunuzun kaynaklarına (CPU, RAM) ve oluşturulan görüntünün boyutuna bağlıdır (görüntü boyutu ne kadar küçükse, oluşturma o kadar hızlı olur). diff --git a/presentation/src/main/res/values-uk/strings.xml b/presentation/src/main/res/values-uk/strings.xml index 2b458faa..ccb95b7b 100644 --- a/presentation/src/main/res/values-uk/strings.xml +++ b/presentation/src/main/res/values-uk/strings.xml @@ -62,6 +62,7 @@ Пакетна генерація Підтримка моделей Генерація без інтернету + Власний сервер CFG Шкала: %1$s Метод вибірки Стиль @@ -94,9 +95,11 @@ HTTP Базова Вкажіть свою URL-адресу Stable Diffusion WebUI + Веб-інтерфейс для Stable Diffusion, реалізований за допомогою бібліотеки Gradio. Ось приклади URL-адрес сервера:\nhttp://192.168.0.2:7860\nhttp://yourdomain.com:7860\nhttps://yourdomain.com Цей режим дозволяє перевірити поведінку програми, навіть якщо у вас немає сервера Stable Diffusion WebUI.\n\nУ демонстраційному режимі програма ігнорує параметри генерації, не використовує сервер штучного інтелекту та повертає фіктивні зображення. Перед підключенням переконайтеся, що:\n• ви використовуєте AUTOMATIC1111 WebUI з аргументами --api --listen\n• ваш брандмауер не блокує порт 7860\n• телефон підключено до однієї мережі Wi-Fi з вашим ПК + Перед підключенням переконайтеся, що:\n• ваш брандмауер не блокує порт 7860\n• телефон підключено до однієї мережі Wi-Fi з вашим ПК Підключитися до Horde AI Horde AI — це краудсорсинговий розподілений кластер нод генерації зображень і тексту. @@ -120,6 +123,9 @@ Про Stability AI Stability AI двигун + Provide your Swarm UI URL + Модульний веб-інтерфейс Stable Diffusion з наголосом на полегшення доступу до інструментів, високу продуктивність і розширюваність. + Ця конфігурація дозволяє запускати генерації Stable Diffusion на вашому телефоні без необхідності підключатися до віддаленого сервера/хмари. УВАГА! Функціональність Local Diffusion у бета-тестуванні. Не очікуйте високоякісних зображень у локальному режимі. \n\nЦя реалізація може не працювати належним чином на телефонах із слабкою потужністю. Продуктивність і швидкість генерації залежать від ресурсів вашого телефону (ЦП, ОЗУ) і розміру згенерованого зображення (чим менший розмір зображення, тим швидше генерується). diff --git a/presentation/src/main/res/values-zh/strings.xml b/presentation/src/main/res/values-zh/strings.xml index e2d9c8bd..5b127047 100644 --- a/presentation/src/main/res/values-zh/strings.xml +++ b/presentation/src/main/res/values-zh/strings.xml @@ -86,6 +86,7 @@ 批量 多个模型 离线生成 + 拥有自己的服务器 CFG缩放: %1$s 采样方法 样式预设 @@ -121,9 +122,11 @@ 提供您的Stable Diffusion WebUI URL + 使用 Gradio 库实现的稳定扩散的 Web 界面。 以下是服务器URL的示例:\n• http://192.168.0.2:7860\n• http://yourdomain.com:7860\n• https://yourdomain.com 此模式允许您测试应用程序的行为,即使您没有Stable Diffusion WebUI服务器。\n\n在演示模式下,应用程序忽略用户提示,不使用AI服务器,并返回一些模拟图像。 在连接之前确保:\n• 您正在运行AUTOMATIC1111 WebUI并带有标志 --api --listen\n• 您的防火墙没有阻止7860端口\n• 手机与您的PC在同一WiFi下 + 在连接之前确保:\n• 您的防火墙没有阻止7860端口\n• 手机与您的PC在同一WiFi下 连接到Horde AI云 @@ -151,6 +154,10 @@ 关于Stability AI Stability AI引擎 + + 提供你的 Swarm UI URL + 模块化稳定扩散 Web 用户界面,重点在于使工具易于访问、高性能和可扩展性。 + 本地扩散 此配置允许在您的手机上运行Stable Diffusion AI生成,无需连接到远程服务器/云。 From 6a74573c9fd4911fcecd04de13b0e3886c99b66e Mon Sep 17 00:00:00 2001 From: ShiftHackZ Date: Mon, 5 Aug 2024 13:12:41 +0300 Subject: [PATCH 08/11] Unit test coverage 1 --- .../Base64EncodingConverter.kt | 36 +++ .../di/ImageProcessingModule.kt | 5 + .../local/SwarmUiModelsLocalDataSource.kt | 6 +- .../mappers/ImageToImagePayloadMappers.kt | 4 +- .../SwarmUiGenerationRemoteDataSource.kt | 24 +- .../remote/SwarmUiSessionDataSourceImpl.kt | 45 ++- .../repository/EmbeddingsRepositoryImpl.kt | 1 + .../data/repository/LorasRepositoryImpl.kt | 1 + .../SwarmUiGenerationRepositoryImpl.kt | 2 + .../repository/SwarmUiModelsRepositoryImpl.kt | 1 + ...rceTest.kt => LorasLocalDataSourceTest.kt} | 2 +- .../local/SwarmUiModelsLocalDataSourceTest.kt | 117 +++++++ .../mocks/SwarmUiGenerationResponseMocks.kt | 7 + .../data/mocks/SwarmUiModelEntityMocks.kt | 12 + .../aisdv1/data/mocks/SwarmUiModelMocks.kt | 11 + .../aisdv1/data/mocks/SwarmUiModelRawMocks.kt | 11 + ...ableDiffusionLorasRemoteDataSourceTest.kt} | 4 +- .../SwarmUiEmbeddingsRemoteDataSourceTest.kt | 79 +++++ .../SwarmUiGenerationRemoteDataSourceTest.kt | 130 ++++++++ .../SwarmUiLorasRemoteDataSourceTest.kt | 79 +++++ .../SwarmUiModelsRemoteDataSourceTest.kt | 79 +++++ .../SwarmUiSessionDataSourceImplTest.kt | 92 ++++++ .../EmbeddingsRepositoryImplTest.kt | 286 +++++++++++++++-- .../repository/LorasRepositoryImplTest.kt | 288 ++++++++++++++++-- .../SwarmUiGenerationRepositoryImplTest.kt | 211 +++++++++++++ .../SwarmUiModelsRepositoryImplTest.kt | 200 ++++++++++++ .../datasource/SwarmUiSessionDataSource.kt | 2 + .../aisdv1/domain/mocks/ConfigurationMocks.kt | 2 + .../aisdv1/domain/mocks/LoraMocks.kt | 16 + .../generation/ImageToImageUseCaseImplTest.kt | 3 + .../generation/TextToImageUseCaseImplTest.kt | 3 + .../sdlora/FetchAndGetLorasUseCaseImplTest.kt | 55 ++++ .../GetConfigurationUseCaseImplTest.kt | 8 + .../SetServerConfigurationUseCaseImplTest.kt | 8 + .../network/api/swarmui/SwarmUiApiImpl.kt | 20 +- .../network/exception/BadSessionException.kt | 3 + .../screen/settings/SettingsScreen.kt | 1 - .../src/main/res/values-tr/strings.xml | 4 +- .../presentation/mocks/SwarmUiModelMocks.kt | 11 + .../screen/setup/ServerSetupViewModelTest.kt | 14 +- .../engine/EngineSelectionViewModelTest.kt | 36 ++- 41 files changed, 1811 insertions(+), 108 deletions(-) create mode 100644 core/imageprocessing/src/main/java/com/shifthackz/aisdv1/core/imageprocessing/Base64EncodingConverter.kt rename data/src/test/java/com/shifthackz/aisdv1/data/local/{StableDiffusionLorasLocalDataSourceTest.kt => LorasLocalDataSourceTest.kt} (98%) create mode 100644 data/src/test/java/com/shifthackz/aisdv1/data/local/SwarmUiModelsLocalDataSourceTest.kt create mode 100644 data/src/test/java/com/shifthackz/aisdv1/data/mocks/SwarmUiGenerationResponseMocks.kt create mode 100644 data/src/test/java/com/shifthackz/aisdv1/data/mocks/SwarmUiModelEntityMocks.kt create mode 100644 data/src/test/java/com/shifthackz/aisdv1/data/mocks/SwarmUiModelMocks.kt create mode 100644 data/src/test/java/com/shifthackz/aisdv1/data/mocks/SwarmUiModelRawMocks.kt rename data/src/test/java/com/shifthackz/aisdv1/data/remote/{StableDiffusionStableDiffusionLorasRemoteDataSourceTest.kt => StableDiffusionLorasRemoteDataSourceTest.kt} (95%) create mode 100644 data/src/test/java/com/shifthackz/aisdv1/data/remote/SwarmUiEmbeddingsRemoteDataSourceTest.kt create mode 100644 data/src/test/java/com/shifthackz/aisdv1/data/remote/SwarmUiGenerationRemoteDataSourceTest.kt create mode 100644 data/src/test/java/com/shifthackz/aisdv1/data/remote/SwarmUiLorasRemoteDataSourceTest.kt create mode 100644 data/src/test/java/com/shifthackz/aisdv1/data/remote/SwarmUiModelsRemoteDataSourceTest.kt create mode 100644 data/src/test/java/com/shifthackz/aisdv1/data/remote/SwarmUiSessionDataSourceImplTest.kt create mode 100644 data/src/test/java/com/shifthackz/aisdv1/data/repository/SwarmUiGenerationRepositoryImplTest.kt create mode 100644 data/src/test/java/com/shifthackz/aisdv1/data/repository/SwarmUiModelsRepositoryImplTest.kt create mode 100644 domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/LoraMocks.kt create mode 100644 domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/sdlora/FetchAndGetLorasUseCaseImplTest.kt create mode 100644 network/src/main/java/com/shifthackz/aisdv1/network/exception/BadSessionException.kt create mode 100644 presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/SwarmUiModelMocks.kt diff --git a/core/imageprocessing/src/main/java/com/shifthackz/aisdv1/core/imageprocessing/Base64EncodingConverter.kt b/core/imageprocessing/src/main/java/com/shifthackz/aisdv1/core/imageprocessing/Base64EncodingConverter.kt new file mode 100644 index 00000000..27542df2 --- /dev/null +++ b/core/imageprocessing/src/main/java/com/shifthackz/aisdv1/core/imageprocessing/Base64EncodingConverter.kt @@ -0,0 +1,36 @@ +package com.shifthackz.aisdv1.core.imageprocessing + +import com.shifthackz.aisdv1.core.common.log.errorLog +import com.shifthackz.aisdv1.core.imageprocessing.Base64EncodingConverter.Input +import com.shifthackz.aisdv1.core.imageprocessing.Base64EncodingConverter.Output +import com.shifthackz.aisdv1.core.imageprocessing.contract.RxImageProcessor +import com.shifthackz.aisdv1.core.imageprocessing.utils.base64DefaultToNoWrap +import io.reactivex.rxjava3.core.Scheduler +import io.reactivex.rxjava3.core.Single + +private typealias Base64EncodingProcessor = RxImageProcessor + +class Base64EncodingConverter( + private val processingScheduler: Scheduler, +) : Base64EncodingProcessor { + + override fun invoke(input: Input): Single = Single + .create { emitter -> + convert(input).fold( + onSuccess = emitter::onSuccess, + onFailure = emitter::onError, + ) + } + .onErrorReturn { t -> + errorLog(t) + Output(input.base64) + } + .subscribeOn(processingScheduler) + + private fun convert(input: Input): Result = runCatching { + Output(base64DefaultToNoWrap(input.base64)) + } + + data class Input(val base64: String) + data class Output(val base64: String) +} diff --git a/core/imageprocessing/src/main/java/com/shifthackz/aisdv1/core/imageprocessing/di/ImageProcessingModule.kt b/core/imageprocessing/src/main/java/com/shifthackz/aisdv1/core/imageprocessing/di/ImageProcessingModule.kt index b7436b0a..5608ffab 100644 --- a/core/imageprocessing/src/main/java/com/shifthackz/aisdv1/core/imageprocessing/di/ImageProcessingModule.kt +++ b/core/imageprocessing/src/main/java/com/shifthackz/aisdv1/core/imageprocessing/di/ImageProcessingModule.kt @@ -2,6 +2,7 @@ package com.shifthackz.aisdv1.core.imageprocessing.di import android.graphics.BitmapFactory import com.shifthackz.aisdv1.core.common.schedulers.SchedulersProvider +import com.shifthackz.aisdv1.core.imageprocessing.Base64EncodingConverter import com.shifthackz.aisdv1.core.imageprocessing.Base64ToBitmapConverter import com.shifthackz.aisdv1.core.imageprocessing.BitmapToBase64Converter import com.shifthackz.aisdv1.core.imageprocessing.R @@ -20,4 +21,8 @@ val imageProcessingModule = module { factory { BitmapToBase64Converter(get().computation) } + + factory { + Base64EncodingConverter(get().computation) + } } diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/local/SwarmUiModelsLocalDataSource.kt b/data/src/main/java/com/shifthackz/aisdv1/data/local/SwarmUiModelsLocalDataSource.kt index 003c61a5..a40cec23 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/local/SwarmUiModelsLocalDataSource.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/local/SwarmUiModelsLocalDataSource.kt @@ -17,7 +17,7 @@ internal class SwarmUiModelsLocalDataSource( .queryAll() .map(List::mapEntityToDomain) - override fun insertModels(models: List): Completable = models - .mapDomainToEntity() - .let(dao::insertList) + override fun insertModels(models: List): Completable = dao + .deleteAll() + .andThen(dao.insertList(models.mapDomainToEntity())) } diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/ImageToImagePayloadMappers.kt b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/ImageToImagePayloadMappers.kt index 8e93311c..f764da02 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/ImageToImagePayloadMappers.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/ImageToImagePayloadMappers.kt @@ -1,7 +1,6 @@ package com.shifthackz.aisdv1.data.mappers import com.shifthackz.aisdv1.core.common.math.roundTo -import com.shifthackz.aisdv1.core.imageprocessing.utils.base64DefaultToNoWrap import com.shifthackz.aisdv1.domain.entity.AiGenerationResult import com.shifthackz.aisdv1.domain.entity.ImageToImagePayload import com.shifthackz.aisdv1.domain.entity.StabilityAiClipGuidance @@ -104,7 +103,8 @@ fun ImageToImagePayload.mapToSwarmUiRequest( SwarmUiGenerationRequest( sessionId = sessionId, model = swarmUiModel, - initImage = "data:image/png;base64,${base64DefaultToNoWrap(base64Image)}", +// initImage = "data:image/png;base64,${base64DefaultToNoWrap(base64Image)}", + initImage = base64Image, initImageCreativity = denoisingStrength.roundTo(2).toString(), images = 1, prompt = prompt, diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/remote/SwarmUiGenerationRemoteDataSource.kt b/data/src/main/java/com/shifthackz/aisdv1/data/remote/SwarmUiGenerationRemoteDataSource.kt index ed7fda3d..913561b1 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/remote/SwarmUiGenerationRemoteDataSource.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/remote/SwarmUiGenerationRemoteDataSource.kt @@ -1,6 +1,7 @@ package com.shifthackz.aisdv1.data.remote import com.shifthackz.aisdv1.core.common.extensions.fixUrlSlashes +import com.shifthackz.aisdv1.core.imageprocessing.Base64EncodingConverter import com.shifthackz.aisdv1.core.imageprocessing.BitmapToBase64Converter import com.shifthackz.aisdv1.data.mappers.mapCloudToAiGenResult import com.shifthackz.aisdv1.data.mappers.mapToSwarmUiRequest @@ -17,7 +18,8 @@ import io.reactivex.rxjava3.core.Single class SwarmUiGenerationRemoteDataSource( private val serverUrlProvider: ServerUrlProvider, private val api: SwarmUiApi, - private val converter: BitmapToBase64Converter, + private val bmpToBase64Converter: BitmapToBase64Converter, + private val base64EncodingConverter: Base64EncodingConverter, ) : SwarmUiGenerationDataSource.Remote { override fun textToImage( @@ -35,11 +37,19 @@ class SwarmUiGenerationRemoteDataSource( sessionId: String, model: String, payload: ImageToImagePayload, - ): Single = - generate( - payload = payload, - request = payload.mapToSwarmUiRequest(sessionId, model), - ) + ): Single = payload + .base64Image + .let(Base64EncodingConverter::Input) + .let(base64EncodingConverter::invoke) + .map(Base64EncodingConverter.Output::base64) + .map { base64 -> "data:image/png;base64,${base64}" } + .map { base64Uri -> payload.copy(base64Image = base64Uri) } + .flatMap { encodedPayload -> + generate( + payload = encodedPayload, + request = encodedPayload.mapToSwarmUiRequest(sessionId, model), + ) + } .map(Pair::mapCloudToAiGenResult) private fun generate( @@ -58,7 +68,7 @@ class SwarmUiGenerationRemoteDataSource( } .flatMap(api::downloadImage) .map(BitmapToBase64Converter::Input) - .flatMap(converter::invoke) + .flatMap(bmpToBase64Converter::invoke) .map(BitmapToBase64Converter.Output::base64ImageString) .map { base64 -> payload to base64 } } diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/remote/SwarmUiSessionDataSourceImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/remote/SwarmUiSessionDataSourceImpl.kt index 5e66e0ab..8c191c30 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/remote/SwarmUiSessionDataSourceImpl.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/remote/SwarmUiSessionDataSourceImpl.kt @@ -6,6 +6,7 @@ import com.shifthackz.aisdv1.domain.datasource.SwarmUiSessionDataSource import com.shifthackz.aisdv1.domain.preference.SessionPreference import com.shifthackz.aisdv1.network.api.swarmui.SwarmUiApi import com.shifthackz.aisdv1.network.api.swarmui.SwarmUiApi.Companion.PATH_SESSION +import com.shifthackz.aisdv1.network.exception.BadSessionException import io.reactivex.rxjava3.core.Single internal class SwarmUiSessionDataSourceImpl( @@ -16,23 +17,35 @@ internal class SwarmUiSessionDataSourceImpl( override fun getSessionId(connectUrl: String?): Single = if (sessionPreference.swarmUiSessionId.isBlank()) { - val chain = connectUrl - ?.let { url -> "$url/$PATH_SESSION".fixUrlSlashes() } - ?.let(api::getNewSession) - ?: serverUrlProvider(PATH_SESSION).flatMap(api::getNewSession) - - chain - .flatMap { response -> - response.sessionId - ?.takeIf(String::isNotBlank) - ?.let { Single.just(it) } - ?: Single.error(IllegalStateException("Bad session ID.")) - } - .map { sessionId -> - sessionPreference.swarmUiSessionId = sessionId - sessionId - } + forceRenew(connectUrl) } else { Single.just(sessionPreference.swarmUiSessionId) } + + override fun forceRenew(connectUrl: String?): Single { + val chain = connectUrl + ?.let { url -> "$url/$PATH_SESSION".fixUrlSlashes() } + ?.let(api::getNewSession) + ?: serverUrlProvider(PATH_SESSION).flatMap(api::getNewSession) + + return chain + .flatMap { response -> + response.sessionId + ?.takeIf(String::isNotBlank) + ?.let { Single.just(it) } + ?: Single.error(IllegalStateException("Bad session ID.")) + } + .map { sessionId -> + sessionPreference.swarmUiSessionId = sessionId + sessionId + } + } + + override fun handleSessionError(chain: Single): Single = chain.onErrorResumeNext { t -> + if (t is BadSessionException) { + forceRenew().flatMap { chain } + } else { + Single.error(t) + } + } } diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/repository/EmbeddingsRepositoryImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/repository/EmbeddingsRepositoryImpl.kt index 6b7ef62a..d2931538 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/repository/EmbeddingsRepositoryImpl.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/repository/EmbeddingsRepositoryImpl.kt @@ -25,6 +25,7 @@ internal class EmbeddingsRepositoryImpl( ServerSource.SWARM_UI -> swarmSession .getSessionId() .flatMap(rdsSwarm::fetchEmbeddings) + .let(swarmSession::handleSessionError) .flatMapCompletable(lds::insertEmbeddings) else -> Completable.complete() diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/repository/LorasRepositoryImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/repository/LorasRepositoryImpl.kt index 5e0fecaa..745657e2 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/repository/LorasRepositoryImpl.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/repository/LorasRepositoryImpl.kt @@ -25,6 +25,7 @@ internal class LorasRepositoryImpl( ServerSource.SWARM_UI -> swarmSession .getSessionId() .flatMap(rdsSwarm::fetchLoras) + .let(swarmSession::handleSessionError) .flatMapCompletable(lds::insertLoras) else -> Completable.complete() diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/repository/SwarmUiGenerationRepositoryImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/repository/SwarmUiGenerationRepositoryImpl.kt index 1ea98007..eccf7001 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/repository/SwarmUiGenerationRepositoryImpl.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/repository/SwarmUiGenerationRepositoryImpl.kt @@ -45,6 +45,7 @@ internal class SwarmUiGenerationRepositoryImpl( payload = payload, ) } + .let(session::handleSessionError) .flatMap(::insertGenerationResult) override fun generateFromImage(payload: ImageToImagePayload): Single = session @@ -56,5 +57,6 @@ internal class SwarmUiGenerationRepositoryImpl( payload = payload, ) } + .let(session::handleSessionError) .flatMap(::insertGenerationResult) } diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/repository/SwarmUiModelsRepositoryImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/repository/SwarmUiModelsRepositoryImpl.kt index cda65ffb..e2d6ca4c 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/repository/SwarmUiModelsRepositoryImpl.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/repository/SwarmUiModelsRepositoryImpl.kt @@ -16,6 +16,7 @@ internal class SwarmUiModelsRepositoryImpl( override fun fetchModels(): Completable = session .getSessionId() .flatMap(rds::fetchSwarmModels) + .let(session::handleSessionError) .flatMapCompletable(lds::insertModels) override fun fetchAndGetModels(): Single> = fetchModels() diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/local/StableDiffusionLorasLocalDataSourceTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/local/LorasLocalDataSourceTest.kt similarity index 98% rename from data/src/test/java/com/shifthackz/aisdv1/data/local/StableDiffusionLorasLocalDataSourceTest.kt rename to data/src/test/java/com/shifthackz/aisdv1/data/local/LorasLocalDataSourceTest.kt index 8ddec896..20c265eb 100644 --- a/data/src/test/java/com/shifthackz/aisdv1/data/local/StableDiffusionLorasLocalDataSourceTest.kt +++ b/data/src/test/java/com/shifthackz/aisdv1/data/local/LorasLocalDataSourceTest.kt @@ -9,7 +9,7 @@ import io.reactivex.rxjava3.core.Completable import io.reactivex.rxjava3.core.Single import org.junit.Test -class StableDiffusionLorasLocalDataSourceTest { +class LorasLocalDataSourceTest { private val stubException = Throwable("Database error.") private val stubDao = mockk() diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/local/SwarmUiModelsLocalDataSourceTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/local/SwarmUiModelsLocalDataSourceTest.kt new file mode 100644 index 00000000..34ebbe56 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/local/SwarmUiModelsLocalDataSourceTest.kt @@ -0,0 +1,117 @@ +package com.shifthackz.aisdv1.data.local + +import com.shifthackz.aisdv1.data.mocks.mockSwarmUiModelEntities +import com.shifthackz.aisdv1.data.mocks.mockSwarmUiModels +import com.shifthackz.aisdv1.storage.db.cache.dao.SwarmUiModelDao +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single +import org.junit.Test + +class SwarmUiModelsLocalDataSourceTest { + + private val stubException = Throwable("Database error.") + private val stubDao = mockk() + + private val localDataSource = SwarmUiModelsLocalDataSource(stubDao) + + @Test + fun `given attempt to get models, dao returns list, expected valid domain model list value`() { + every { + stubDao.queryAll() + } returns Single.just(mockSwarmUiModelEntities) + + localDataSource + .getModels() + .test() + .assertNoErrors() + .assertValue { it.size == mockSwarmUiModels.size } + .await() + .assertComplete() + } + + @Test + fun `given attempt to get models, dao returns empty list, expected empty domain model list value`() { + every { + stubDao.queryAll() + } returns Single.just(emptyList()) + + localDataSource + .getModels() + .test() + .assertNoErrors() + .assertValue(emptyList()) + .await() + .assertComplete() + } + + @Test + fun `given attempt to get models, dao throws exception, expected error value`() { + every { + stubDao.queryAll() + } returns Single.error(stubException) + + localDataSource + .getModels() + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to insert models, dao replaces list, expected complete value`() { + every { + stubDao.deleteAll() + } returns Completable.complete() + + every { + stubDao.insertList(any()) + } returns Completable.complete() + + localDataSource + .insertModels(mockSwarmUiModels) + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `given attempt to insert models, dao throws exception during delete, expected error value`() { + every { + stubDao.deleteAll() + } returns Completable.error(stubException) + + every { + stubDao.insertList(any()) + } returns Completable.complete() + + localDataSource + .insertModels(mockSwarmUiModels) + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to insert models, dao throws exception during insertion, expected error value`() { + every { + stubDao.deleteAll() + } returns Completable.complete() + + every { + stubDao.insertList(any()) + } returns Completable.error(stubException) + + localDataSource + .insertModels(mockSwarmUiModels) + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/SwarmUiGenerationResponseMocks.kt b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/SwarmUiGenerationResponseMocks.kt new file mode 100644 index 00000000..e92ab561 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/SwarmUiGenerationResponseMocks.kt @@ -0,0 +1,7 @@ +package com.shifthackz.aisdv1.data.mocks + +import com.shifthackz.aisdv1.network.response.SwarmUiGenerationResponse + +val mockSwarmUiGenerationResponse = SwarmUiGenerationResponse( + images = listOf("/tmp/img.jpg"), +) diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/SwarmUiModelEntityMocks.kt b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/SwarmUiModelEntityMocks.kt new file mode 100644 index 00000000..128e54f3 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/SwarmUiModelEntityMocks.kt @@ -0,0 +1,12 @@ +package com.shifthackz.aisdv1.data.mocks + +import com.shifthackz.aisdv1.storage.db.cache.entity.SwarmUiModelEntity + +val mockSwarmUiModelEntities = listOf( + SwarmUiModelEntity( + "5598", + "5598", + "5598", + "", + ) +) diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/SwarmUiModelMocks.kt b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/SwarmUiModelMocks.kt new file mode 100644 index 00000000..22d794cd --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/SwarmUiModelMocks.kt @@ -0,0 +1,11 @@ +package com.shifthackz.aisdv1.data.mocks + +import com.shifthackz.aisdv1.domain.entity.SwarmUiModel + +val mockSwarmUiModels = listOf( + SwarmUiModel( + name = "5598", + title = "5598", + author = "5598", + ) +) diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/SwarmUiModelRawMocks.kt b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/SwarmUiModelRawMocks.kt new file mode 100644 index 00000000..19c9218a --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/SwarmUiModelRawMocks.kt @@ -0,0 +1,11 @@ +package com.shifthackz.aisdv1.data.mocks + +import com.shifthackz.aisdv1.network.model.SwarmUiModelRaw + +val mockSwarmUiModelsRaw = listOf( + SwarmUiModelRaw( + name = "5598", + title = "5598", + author = "5598", + ), +) diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/remote/StableDiffusionStableDiffusionLorasRemoteDataSourceTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/remote/StableDiffusionLorasRemoteDataSourceTest.kt similarity index 95% rename from data/src/test/java/com/shifthackz/aisdv1/data/remote/StableDiffusionStableDiffusionLorasRemoteDataSourceTest.kt rename to data/src/test/java/com/shifthackz/aisdv1/data/remote/StableDiffusionLorasRemoteDataSourceTest.kt index 7b0ae55c..a3e129ef 100644 --- a/data/src/test/java/com/shifthackz/aisdv1/data/remote/StableDiffusionStableDiffusionLorasRemoteDataSourceTest.kt +++ b/data/src/test/java/com/shifthackz/aisdv1/data/remote/StableDiffusionLorasRemoteDataSourceTest.kt @@ -10,7 +10,7 @@ import io.reactivex.rxjava3.core.Single import org.junit.Before import org.junit.Test -class StableDiffusionStableDiffusionLorasRemoteDataSourceTest { +class StableDiffusionLorasRemoteDataSourceTest { private val stubException = Throwable("Internal server error.") private val stubUrlProvider = mockk() @@ -18,7 +18,7 @@ class StableDiffusionStableDiffusionLorasRemoteDataSourceTest { private val remoteDataSource = StableDiffusionLorasRemoteDataSource( serverUrlProvider = stubUrlProvider, - automatic1111Api = stubApi, + api = stubApi, ) @Before diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/remote/SwarmUiEmbeddingsRemoteDataSourceTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/remote/SwarmUiEmbeddingsRemoteDataSourceTest.kt new file mode 100644 index 00000000..d0fbe69d --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/remote/SwarmUiEmbeddingsRemoteDataSourceTest.kt @@ -0,0 +1,79 @@ +package com.shifthackz.aisdv1.data.remote + +import com.shifthackz.aisdv1.data.mocks.mockSwarmUiModelsRaw +import com.shifthackz.aisdv1.data.provider.ServerUrlProvider +import com.shifthackz.aisdv1.domain.entity.Embedding +import com.shifthackz.aisdv1.network.api.swarmui.SwarmUiApi +import com.shifthackz.aisdv1.network.response.SwarmUiModelsResponse +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Single +import org.junit.Before +import org.junit.Test + +class SwarmUiEmbeddingsRemoteDataSourceTest { + + private val stubException = Throwable("Something went wrong.") + private val stubServerUrlProvider = mockk() + private val stubApi = mockk() + + private val remoteDataSource = SwarmUiEmbeddingsRemoteDataSource( + serverUrlProvider = stubServerUrlProvider, + api = stubApi, + ) + + @Before + fun initialize() { + every { + stubServerUrlProvider(any()) + } returns Single.just("http://192.168.0.1:7801") + } + + @Test + fun `given attempt to fetch models, api returns success response, expected valid models list value`() { + every { + stubApi.fetchModels(any(), any()) + } returns Single.just(SwarmUiModelsResponse(mockSwarmUiModelsRaw)) + + remoteDataSource + .fetchEmbeddings("5598") + .test() + .assertNoErrors() + .assertValue { models -> + models is List + && models.size == mockSwarmUiModelsRaw.size + } + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch models, api returns empty response, expected empty models value`() { + every { + stubApi.fetchModels(any(), any()) + } returns Single.just(SwarmUiModelsResponse(emptyList())) + + remoteDataSource + .fetchEmbeddings("5598") + .test() + .assertNoErrors() + .assertValue(emptyList()) + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch models, api returns error response, expected error value`() { + every { + stubApi.fetchModels(any(), any()) + } returns Single.error(stubException) + + remoteDataSource + .fetchEmbeddings("5598") + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/remote/SwarmUiGenerationRemoteDataSourceTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/remote/SwarmUiGenerationRemoteDataSourceTest.kt new file mode 100644 index 00000000..29c431f4 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/remote/SwarmUiGenerationRemoteDataSourceTest.kt @@ -0,0 +1,130 @@ +package com.shifthackz.aisdv1.data.remote + +import android.graphics.Bitmap +import com.shifthackz.aisdv1.core.imageprocessing.Base64EncodingConverter +import com.shifthackz.aisdv1.core.imageprocessing.BitmapToBase64Converter +import com.shifthackz.aisdv1.data.mocks.mockImageToImagePayload +import com.shifthackz.aisdv1.data.mocks.mockSwarmUiGenerationResponse +import com.shifthackz.aisdv1.data.mocks.mockTextToImagePayload +import com.shifthackz.aisdv1.data.provider.ServerUrlProvider +import com.shifthackz.aisdv1.domain.entity.AiGenerationResult +import com.shifthackz.aisdv1.network.api.swarmui.SwarmUiApi +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Single +import org.junit.Before +import org.junit.Test + +class SwarmUiGenerationRemoteDataSourceTest { + + private val stubBitmap = mockk() + private val stubException = Throwable("Something went wrong.") + private val stubServerUrlProvider = mockk() + private val stubApi = mockk() + private val stubBmpToBase64Converter = mockk() + private val stubBase64EncConverter = mockk() + + private val remoteDataSource = SwarmUiGenerationRemoteDataSource( + serverUrlProvider = stubServerUrlProvider, + api = stubApi, + bmpToBase64Converter = stubBmpToBase64Converter, + base64EncodingConverter = stubBase64EncConverter, + ) + + @Before + fun initialize() { + every { + stubServerUrlProvider(any()) + } returns Single.just("http://192.168.0.1:7801") + } + + + @Test + fun `given attempt to generate txt2img, api returns result, expected valid ai generation result value`() { + every { + stubApi.generate(any(), any()) + } returns Single.just(mockSwarmUiGenerationResponse) + + every { + stubApi.downloadImage(any()) + } returns Single.just(stubBitmap) + + every { + stubBmpToBase64Converter(any()) + } returns Single.just(BitmapToBase64Converter.Output("base64")) + + remoteDataSource + .textToImage(SESSION_ID, MODEL, mockTextToImagePayload) + .test() + .assertNoErrors() + .assertValue { it is AiGenerationResult } + .await() + .assertComplete() + } + + @Test + fun `given attempt to generate txt2img, api returns error, expected error value`() { + every { + stubApi.generate(any(), any()) + } returns Single.error(stubException) + + remoteDataSource + .textToImage(SESSION_ID, MODEL, mockTextToImagePayload) + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to generate img2img, api returns result, expected valid ai generation result value`() { + every { + stubBase64EncConverter(any()) + } returns Single.just(Base64EncodingConverter.Output("base64")) + + every { + stubApi.generate(any(), any()) + } returns Single.just(mockSwarmUiGenerationResponse) + + every { + stubApi.downloadImage(any()) + } returns Single.just(stubBitmap) + + every { + stubBmpToBase64Converter(any()) + } returns Single.just(BitmapToBase64Converter.Output("base64")) + + remoteDataSource + .imageToImage(SESSION_ID, MODEL, mockImageToImagePayload) + .test() + .assertNoErrors() + .assertValue { it is AiGenerationResult } + .await() + .assertComplete() + } + + @Test + fun `given attempt to generate img2img, api returns error, expected error value`() { + every { + stubBase64EncConverter(any()) + } returns Single.error(stubException) + + every { + stubApi.generate(any(), any()) + } returns Single.error(stubException) + + remoteDataSource + .imageToImage(SESSION_ID, MODEL, mockImageToImagePayload) + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + companion object { + private const val SESSION_ID = "5598" + private const val MODEL = "OpenStableDiffusion" + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/remote/SwarmUiLorasRemoteDataSourceTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/remote/SwarmUiLorasRemoteDataSourceTest.kt new file mode 100644 index 00000000..ed70d23e --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/remote/SwarmUiLorasRemoteDataSourceTest.kt @@ -0,0 +1,79 @@ +package com.shifthackz.aisdv1.data.remote + +import com.shifthackz.aisdv1.data.mocks.mockSwarmUiModelsRaw +import com.shifthackz.aisdv1.data.provider.ServerUrlProvider +import com.shifthackz.aisdv1.domain.entity.LoRA +import com.shifthackz.aisdv1.network.api.swarmui.SwarmUiApi +import com.shifthackz.aisdv1.network.response.SwarmUiModelsResponse +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Single +import org.junit.Before +import org.junit.Test + +class SwarmUiLorasRemoteDataSourceTest { + + private val stubException = Throwable("Something went wrong.") + private val stubServerUrlProvider = mockk() + private val stubApi = mockk() + + private val remoteDataSource = SwarmUiLorasRemoteDataSource( + serverUrlProvider = stubServerUrlProvider, + api = stubApi, + ) + + @Before + fun initialize() { + every { + stubServerUrlProvider(any()) + } returns Single.just("http://192.168.0.1:7801") + } + + @Test + fun `given attempt to fetch models, api returns success response, expected valid models list value`() { + every { + stubApi.fetchModels(any(), any()) + } returns Single.just(SwarmUiModelsResponse(mockSwarmUiModelsRaw)) + + remoteDataSource + .fetchLoras("5598") + .test() + .assertNoErrors() + .assertValue { models -> + models is List + && models.size == mockSwarmUiModelsRaw.size + } + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch models, api returns empty response, expected empty models value`() { + every { + stubApi.fetchModels(any(), any()) + } returns Single.just(SwarmUiModelsResponse(emptyList())) + + remoteDataSource + .fetchLoras("5598") + .test() + .assertNoErrors() + .assertValue(emptyList()) + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch models, api returns error response, expected error value`() { + every { + stubApi.fetchModels(any(), any()) + } returns Single.error(stubException) + + remoteDataSource + .fetchLoras("5598") + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/remote/SwarmUiModelsRemoteDataSourceTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/remote/SwarmUiModelsRemoteDataSourceTest.kt new file mode 100644 index 00000000..e4a68d0d --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/remote/SwarmUiModelsRemoteDataSourceTest.kt @@ -0,0 +1,79 @@ +package com.shifthackz.aisdv1.data.remote + +import com.shifthackz.aisdv1.data.mocks.mockSwarmUiModelsRaw +import com.shifthackz.aisdv1.data.provider.ServerUrlProvider +import com.shifthackz.aisdv1.domain.entity.SwarmUiModel +import com.shifthackz.aisdv1.network.api.swarmui.SwarmUiApi +import com.shifthackz.aisdv1.network.response.SwarmUiModelsResponse +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Single +import org.junit.Before +import org.junit.Test + +class SwarmUiModelsRemoteDataSourceTest { + + private val stubException = Throwable("Something went wrong.") + private val stubServerUrlProvider = mockk() + private val stubApi = mockk() + + private val remoteDataSource = SwarmUiModelsRemoteDataSource( + serverUrlProvider = stubServerUrlProvider, + api = stubApi, + ) + + @Before + fun initialize() { + every { + stubServerUrlProvider(any()) + } returns Single.just("http://192.168.0.1:7801") + } + + @Test + fun `given attempt to fetch models, api returns success response, expected valid models list value`() { + every { + stubApi.fetchModels(any(), any()) + } returns Single.just(SwarmUiModelsResponse(mockSwarmUiModelsRaw)) + + remoteDataSource + .fetchSwarmModels("5598") + .test() + .assertNoErrors() + .assertValue { models -> + models is List + && models.size == mockSwarmUiModelsRaw.size + } + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch models, api returns empty response, expected empty models value`() { + every { + stubApi.fetchModels(any(), any()) + } returns Single.just(SwarmUiModelsResponse(emptyList())) + + remoteDataSource + .fetchSwarmModels("5598") + .test() + .assertNoErrors() + .assertValue(emptyList()) + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch models, api returns error response, expected error value`() { + every { + stubApi.fetchModels(any(), any()) + } returns Single.error(stubException) + + remoteDataSource + .fetchSwarmModels("5598") + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/remote/SwarmUiSessionDataSourceImplTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/remote/SwarmUiSessionDataSourceImplTest.kt new file mode 100644 index 00000000..6a1a3aa3 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/remote/SwarmUiSessionDataSourceImplTest.kt @@ -0,0 +1,92 @@ +package com.shifthackz.aisdv1.data.remote + +import com.shifthackz.aisdv1.data.provider.ServerUrlProvider +import com.shifthackz.aisdv1.domain.preference.SessionPreference +import com.shifthackz.aisdv1.network.api.swarmui.SwarmUiApi +import com.shifthackz.aisdv1.network.response.SwarmUiSessionResponse +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Single +import org.junit.Before +import org.junit.Test + +class SwarmUiSessionDataSourceImplTest { + + private val stubApi = mockk() + private val stubSessionPreference = mockk() + private val stubServerUrlProvider = mockk() + + private val remoteDataSource = SwarmUiSessionDataSourceImpl( + api = stubApi, + sessionPreference = stubSessionPreference, + serverUrlProvider = stubServerUrlProvider, + ) + + @Before + fun initialize() { + every { + stubServerUrlProvider(any()) + } returns Single.just("http://192.168.0.1:7801") + } + + @Test + fun `given session present in preference, expected sessionId value from preference`() { + every { + stubSessionPreference::swarmUiSessionId.get() + } returns "5598" + + remoteDataSource + .getSessionId() + .test() + .assertNoErrors() + .await() + .assertValue("5598") + .assertComplete() + } + + @Test + fun `given session NOT present in preference, API returns session, expected sessionId value from API`() { + every { + stubSessionPreference::swarmUiSessionId.get() + } returns "" + + every { + stubSessionPreference::swarmUiSessionId.set(any()) + } returns Unit + + every { + stubApi.getNewSession(any()) + } returns Single.just(SwarmUiSessionResponse("5598")) + + remoteDataSource + .getSessionId() + .test() + .assertNoErrors() + .await() + .assertValue("5598") + .assertComplete() + } + + @Test + fun `given session NOT present in preference, API returns null, expected error value`() { + every { + stubSessionPreference::swarmUiSessionId.get() + } returns "" + + every { + stubSessionPreference::swarmUiSessionId.set(any()) + } returns Unit + + every { + stubApi.getNewSession(any()) + } returns Single.just(SwarmUiSessionResponse(null)) + + remoteDataSource + .getSessionId() + .test() + .assertError { t -> t is IllegalStateException && t.message == "Bad session ID." } + .await() + .assertNoValues() + .assertNotComplete() + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/repository/EmbeddingsRepositoryImplTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/repository/EmbeddingsRepositoryImplTest.kt index acb46779..74e044b1 100644 --- a/data/src/test/java/com/shifthackz/aisdv1/data/repository/EmbeddingsRepositoryImplTest.kt +++ b/data/src/test/java/com/shifthackz/aisdv1/data/repository/EmbeddingsRepositoryImplTest.kt @@ -2,31 +2,82 @@ package com.shifthackz.aisdv1.data.repository import com.shifthackz.aisdv1.data.mocks.mockEmbeddings import com.shifthackz.aisdv1.domain.datasource.EmbeddingsDataSource +import com.shifthackz.aisdv1.domain.datasource.SwarmUiSessionDataSource +import com.shifthackz.aisdv1.domain.entity.ServerSource +import com.shifthackz.aisdv1.domain.preference.PreferenceManager import io.mockk.every import io.mockk.mockk import io.reactivex.rxjava3.core.Completable import io.reactivex.rxjava3.core.Single +import org.junit.Before import org.junit.Test class EmbeddingsRepositoryImplTest { private val stubException = Throwable("Something went wrong.") - private val stubRemoteDataSource = mockk() - private val stubLocalDataSource = mockk() + private val stubRdsA1111 = mockk() + private val stubRdsSwarm = mockk() + private val stubSwarmSession = mockk() + private val stubLds = mockk() + private val stubPreferenceManager = mockk() private val repository = EmbeddingsRepositoryImpl( - rdsA1111 = stubRemoteDataSource, - lds = stubLocalDataSource, + rdsA1111 = stubRdsA1111, + rdsSwarm = stubRdsSwarm, + swarmSession = stubSwarmSession, + lds = stubLds, + preferenceManager = stubPreferenceManager, ) + @Before + fun initialize() { + every { + stubSwarmSession.handleSessionError(any>()) + } returnsArgument 0 + } + + @Test + fun `given attempt to fetch embeddings, source is AUTOMATIC1111, remote returns data, local insert success, expected complete value`() { + every { + stubPreferenceManager::source.get() + } returns ServerSource.AUTOMATIC1111 + + every { + stubRdsA1111.fetchEmbeddings() + } returns Single.just(mockEmbeddings) + + every { + stubLds.insertEmbeddings(any()) + } returns Completable.complete() + + repository + .fetchEmbeddings() + .test() + .assertNoErrors() + .await() + .assertComplete() + } + @Test - fun `given attempt to fetch embeddings, remote returns data, local insert success, expected complete value`() { + fun `given attempt to fetch embeddings, source is SWARM_UI, remote returns data, local insert success, expected complete value`() { + every { + stubPreferenceManager::source.get() + } returns ServerSource.SWARM_UI + + every { + stubSwarmSession.getSessionId(any()) + } returns Single.just("5598") + every { - stubRemoteDataSource.fetchEmbeddings() + stubSwarmSession.getSessionId() + } returns Single.just("5598") + + every { + stubRdsSwarm.fetchEmbeddings(any()) } returns Single.just(mockEmbeddings) every { - stubLocalDataSource.insertEmbeddings(any()) + stubLds.insertEmbeddings(any()) } returns Completable.complete() repository @@ -38,13 +89,17 @@ class EmbeddingsRepositoryImplTest { } @Test - fun `given attempt to fetch embeddings, remote throws exception, local insert success, expected error value`() { + fun `given attempt to fetch embeddings, source is AUTOMATIC1111, remote throws exception, local insert success, expected error value`() { every { - stubRemoteDataSource.fetchEmbeddings() + stubPreferenceManager::source.get() + } returns ServerSource.AUTOMATIC1111 + + every { + stubRdsA1111.fetchEmbeddings() } returns Single.error(stubException) every { - stubLocalDataSource.insertEmbeddings(any()) + stubLds.insertEmbeddings(any()) } returns Completable.complete() repository @@ -56,13 +111,77 @@ class EmbeddingsRepositoryImplTest { } @Test - fun `given attempt to fetch embeddings, remote returns data, local insert fails, expected error value`() { + fun `given attempt to fetch embeddings, source is SWARM_UI, remote throws exception, local insert success, expected error value`() { every { - stubRemoteDataSource.fetchEmbeddings() + stubPreferenceManager::source.get() + } returns ServerSource.SWARM_UI + + every { + stubSwarmSession.getSessionId(any()) + } returns Single.just("5598") + + every { + stubSwarmSession.getSessionId() + } returns Single.just("5598") + + every { + stubRdsSwarm.fetchEmbeddings(any()) + } returns Single.error(stubException) + + every { + stubLds.insertEmbeddings(any()) + } returns Completable.complete() + + repository + .fetchEmbeddings() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to fetch embeddings, source is AUTOMATIC1111, remote returns data, local insert fails, expected error value`() { + every { + stubPreferenceManager::source.get() + } returns ServerSource.AUTOMATIC1111 + + every { + stubRdsA1111.fetchEmbeddings() } returns Single.just(mockEmbeddings) every { - stubLocalDataSource.insertEmbeddings(any()) + stubLds.insertEmbeddings(any()) + } returns Completable.error(stubException) + + repository + .fetchEmbeddings() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to fetch embeddings, source is SWARM_UI, remote returns data, local insert fails, expected error value`() { + every { + stubPreferenceManager::source.get() + } returns ServerSource.SWARM_UI + + every { + stubSwarmSession.getSessionId(any()) + } returns Single.just("5598") + + every { + stubSwarmSession.getSessionId() + } returns Single.just("5598") + + every { + stubRdsSwarm.fetchEmbeddings(any()) + } returns Single.just(mockEmbeddings) + + every { + stubLds.insertEmbeddings(any()) } returns Completable.error(stubException) repository @@ -76,7 +195,7 @@ class EmbeddingsRepositoryImplTest { @Test fun `given attempt to get embeddings, local data source returns list, expected valid domain models list value`() { every { - stubLocalDataSource.getEmbeddings() + stubLds.getEmbeddings() } returns Single.just(mockEmbeddings) repository @@ -91,7 +210,7 @@ class EmbeddingsRepositoryImplTest { @Test fun `given attempt to get embeddings, local data source returns empty list, expected empty domain models list value`() { every { - stubLocalDataSource.getEmbeddings() + stubLds.getEmbeddings() } returns Single.just(emptyList()) repository @@ -106,7 +225,7 @@ class EmbeddingsRepositoryImplTest { @Test fun `given attempt to get embeddings, local data source throws exception, expected error value`() { every { - stubLocalDataSource.getEmbeddings() + stubLds.getEmbeddings() } returns Single.error(stubException) repository @@ -119,17 +238,56 @@ class EmbeddingsRepositoryImplTest { } @Test - fun `given attempt to fetch and get embeddings, remote returns data, local returns data, expected valid domain models list value`() { + fun `given attempt to fetch and get embeddings, source is AUTOMATIC1111, remote returns data, local returns data, expected valid domain models list value`() { + every { + stubPreferenceManager::source.get() + } returns ServerSource.AUTOMATIC1111 + + every { + stubRdsA1111.fetchEmbeddings() + } returns Single.just(mockEmbeddings) + + every { + stubLds.insertEmbeddings(any()) + } returns Completable.complete() + + every { + stubLds.getEmbeddings() + } returns Single.just(mockEmbeddings) + + repository + .fetchAndGetEmbeddings() + .test() + .assertNoErrors() + .assertValue(mockEmbeddings) + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch and get embeddings, source is SWARM_UI, remote returns data, local returns data, expected valid domain models list value`() { + every { + stubPreferenceManager::source.get() + } returns ServerSource.SWARM_UI + + every { + stubSwarmSession.getSessionId(any()) + } returns Single.just("5598") + + every { + stubSwarmSession.getSessionId() + } returns Single.just("5598") + every { - stubRemoteDataSource.fetchEmbeddings() + stubRdsSwarm.fetchEmbeddings(any()) } returns Single.just(mockEmbeddings) every { - stubLocalDataSource.insertEmbeddings(any()) + stubLds.insertEmbeddings(any()) } returns Completable.complete() every { - stubLocalDataSource.getEmbeddings() + stubLds.getEmbeddings() } returns Single.just(mockEmbeddings) repository @@ -142,17 +300,21 @@ class EmbeddingsRepositoryImplTest { } @Test - fun `given attempt to fetch and get embeddings, remote fails, local returns data, expected valid domain models list value`() { + fun `given attempt to fetch and get embeddings, source is AUTOMATIC1111, remote fails, local returns data, expected valid domain models list value`() { every { - stubRemoteDataSource.fetchEmbeddings() + stubPreferenceManager::source.get() + } returns ServerSource.AUTOMATIC1111 + + every { + stubRdsA1111.fetchEmbeddings() } returns Single.error(stubException) every { - stubLocalDataSource.insertEmbeddings(any()) + stubLds.insertEmbeddings(any()) } returns Completable.complete() every { - stubLocalDataSource.getEmbeddings() + stubLds.getEmbeddings() } returns Single.just(mockEmbeddings) repository @@ -165,13 +327,83 @@ class EmbeddingsRepositoryImplTest { } @Test - fun `given attempt to fetch and get embeddings, remote fails, local fails, expected valid error value`() { + fun `given attempt to fetch and get embeddings, source is SWARM_UI, remote fails, local returns data, expected valid domain models list value`() { + every { + stubPreferenceManager::source.get() + } returns ServerSource.SWARM_UI + + every { + stubSwarmSession.getSessionId(any()) + } returns Single.just("5598") + + every { + stubSwarmSession.getSessionId() + } returns Single.just("5598") + + every { + stubRdsA1111.fetchEmbeddings() + } returns Single.error(stubException) + + every { + stubLds.insertEmbeddings(any()) + } returns Completable.complete() + + every { + stubLds.getEmbeddings() + } returns Single.just(mockEmbeddings) + + repository + .fetchAndGetEmbeddings() + .test() + .assertNoErrors() + .assertValue(mockEmbeddings) + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch and get embeddings, source is AUTOMATIC1111, remote fails, local fails, expected valid error value`() { + every { + stubPreferenceManager::source.get() + } returns ServerSource.AUTOMATIC1111 + + every { + stubRdsA1111.fetchEmbeddings() + } returns Single.error(stubException) + + every { + stubLds.getEmbeddings() + } returns Single.error(stubException) + + repository + .fetchAndGetEmbeddings() + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to fetch and get embeddings, source is SWARM_UI, remote fails, local fails, expected valid error value`() { + every { + stubPreferenceManager::source.get() + } returns ServerSource.SWARM_UI + + every { + stubSwarmSession.getSessionId(any()) + } returns Single.just("5598") + + every { + stubSwarmSession.getSessionId() + } returns Single.just("5598") + every { - stubRemoteDataSource.fetchEmbeddings() + stubRdsA1111.fetchEmbeddings() } returns Single.error(stubException) every { - stubLocalDataSource.getEmbeddings() + stubLds.getEmbeddings() } returns Single.error(stubException) repository diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/repository/LorasRepositoryImplTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/repository/LorasRepositoryImplTest.kt index 97513ffe..3bff4b37 100644 --- a/data/src/test/java/com/shifthackz/aisdv1/data/repository/LorasRepositoryImplTest.kt +++ b/data/src/test/java/com/shifthackz/aisdv1/data/repository/LorasRepositoryImplTest.kt @@ -2,31 +2,52 @@ package com.shifthackz.aisdv1.data.repository import com.shifthackz.aisdv1.data.mocks.mockStableDiffusionLoras import com.shifthackz.aisdv1.domain.datasource.LorasDataSource +import com.shifthackz.aisdv1.domain.datasource.SwarmUiSessionDataSource +import com.shifthackz.aisdv1.domain.entity.ServerSource +import com.shifthackz.aisdv1.domain.preference.PreferenceManager import io.mockk.every import io.mockk.mockk import io.reactivex.rxjava3.core.Completable import io.reactivex.rxjava3.core.Single +import org.junit.Before import org.junit.Test class LorasRepositoryImplTest { private val stubException = Throwable("Something went wrong.") - private val stubRemoteDataSource = mockk() - private val stubLocalDataSource = mockk() - + private val stubRdsA1111 = mockk() + private val stubRdsSwarm = mockk() + private val stubSwarmSession = mockk() + private val stubLds = mockk() + private val stubPreferenceManager = mockk() + private val repository = LorasRepositoryImpl( - rdsA1111 = stubRemoteDataSource, - lds = stubLocalDataSource, + rdsA1111 = stubRdsA1111, + rdsSwarm = stubRdsSwarm, + swarmSession = stubSwarmSession, + lds = stubLds, + preferenceManager = stubPreferenceManager, ) + @Before + fun initialize() { + every { + stubSwarmSession.handleSessionError(any>()) + } returnsArgument 0 + } + @Test - fun `given attempt to fetch loras, remote returns data, local insert success, expected complete value`() { + fun `given attempt to fetch loras, source is AUTOMATIC1111, remote returns data, local insert success, expected complete value`() { + every { + stubPreferenceManager::source.get() + } returns ServerSource.AUTOMATIC1111 + every { - stubRemoteDataSource.fetchLoras() + stubRdsA1111.fetchLoras() } returns Single.just(mockStableDiffusionLoras) every { - stubLocalDataSource.insertLoras(any()) + stubLds.insertLoras(any()) } returns Completable.complete() repository @@ -38,13 +59,47 @@ class LorasRepositoryImplTest { } @Test - fun `given attempt to fetch loras, remote throws exception, local insert success, expected error value`() { + fun `given attempt to fetch loras, source is SWARM_UI, remote returns data, local insert success, expected complete value`() { + every { + stubPreferenceManager::source.get() + } returns ServerSource.SWARM_UI + + every { + stubSwarmSession.getSessionId(any()) + } returns Single.just("5598") + every { - stubRemoteDataSource.fetchLoras() + stubSwarmSession.getSessionId() + } returns Single.just("5598") + + every { + stubRdsSwarm.fetchLoras(any()) + } returns Single.just(mockStableDiffusionLoras) + + every { + stubLds.insertLoras(any()) + } returns Completable.complete() + + repository + .fetchLoras() + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch loras, source is AUTOMATIC1111, remote throws exception, local insert success, expected error value`() { + every { + stubPreferenceManager::source.get() + } returns ServerSource.AUTOMATIC1111 + + every { + stubRdsA1111.fetchLoras() } returns Single.error(stubException) every { - stubLocalDataSource.insertLoras(any()) + stubLds.insertLoras(any()) } returns Completable.complete() repository @@ -56,13 +111,77 @@ class LorasRepositoryImplTest { } @Test - fun `given attempt to fetch loras, remote returns data, local insert fails, expected error value`() { + fun `given attempt to fetch loras, source is SWARM_UI, remote throws exception, local insert success, expected error value`() { every { - stubRemoteDataSource.fetchLoras() + stubPreferenceManager::source.get() + } returns ServerSource.SWARM_UI + + every { + stubSwarmSession.getSessionId(any()) + } returns Single.error(stubException) + + every { + stubSwarmSession.getSessionId() + } returns Single.error(stubException) + + every { + stubRdsSwarm.fetchLoras(any()) + } returns Single.error(stubException) + + every { + stubLds.insertLoras(any()) + } returns Completable.complete() + + repository + .fetchLoras() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to fetch loras, source is AUTOMATIC1111, remote returns data, local insert fails, expected error value`() { + every { + stubPreferenceManager::source.get() + } returns ServerSource.AUTOMATIC1111 + + every { + stubRdsA1111.fetchLoras() + } returns Single.just(mockStableDiffusionLoras) + + every { + stubLds.insertLoras(any()) + } returns Completable.error(stubException) + + repository + .fetchLoras() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to fetch loras, source is SWARM_UI, remote returns data, local insert fails, expected error value`() { + every { + stubPreferenceManager::source.get() + } returns ServerSource.SWARM_UI + + every { + stubSwarmSession.getSessionId(any()) + } returns Single.just("5598") + + every { + stubSwarmSession.getSessionId() + } returns Single.just("5598") + + every { + stubRdsSwarm.fetchLoras(any()) } returns Single.just(mockStableDiffusionLoras) every { - stubLocalDataSource.insertLoras(any()) + stubLds.insertLoras(any()) } returns Completable.error(stubException) repository @@ -76,7 +195,7 @@ class LorasRepositoryImplTest { @Test fun `given attempt to get loras, local data source returns list, expected valid domain models list value`() { every { - stubLocalDataSource.getLoras() + stubLds.getLoras() } returns Single.just(mockStableDiffusionLoras) repository @@ -91,7 +210,7 @@ class LorasRepositoryImplTest { @Test fun `given attempt to get loras, local data source returns empty list, expected empty domain models list value`() { every { - stubLocalDataSource.getLoras() + stubLds.getLoras() } returns Single.just(emptyList()) repository @@ -106,7 +225,7 @@ class LorasRepositoryImplTest { @Test fun `given attempt to get loras, local data source throws exception, expected error value`() { every { - stubLocalDataSource.getLoras() + stubLds.getLoras() } returns Single.error(stubException) repository @@ -119,17 +238,21 @@ class LorasRepositoryImplTest { } @Test - fun `given attempt to fetch and get loras, remote returns data, local returns data, expected valid domain models list value`() { + fun `given attempt to fetch and get loras, source is AUTOMATIC1111, remote returns data, local returns data, expected valid domain models list value`() { + every { + stubPreferenceManager::source.get() + } returns ServerSource.AUTOMATIC1111 + every { - stubRemoteDataSource.fetchLoras() + stubRdsA1111.fetchLoras() } returns Single.just(mockStableDiffusionLoras) every { - stubLocalDataSource.insertLoras(any()) + stubLds.insertLoras(any()) } returns Completable.complete() every { - stubLocalDataSource.getLoras() + stubLds.getLoras() } returns Single.just(mockStableDiffusionLoras) repository @@ -142,17 +265,91 @@ class LorasRepositoryImplTest { } @Test - fun `given attempt to fetch and get loras, remote fails, local returns data, expected valid domain models list value`() { + fun `given attempt to fetch and get loras, source is SWARM_UI, remote returns data, local returns data, expected valid domain models list value`() { every { - stubRemoteDataSource.fetchLoras() + stubPreferenceManager::source.get() + } returns ServerSource.SWARM_UI + + every { + stubSwarmSession.getSessionId(any()) + } returns Single.just("5598") + + every { + stubSwarmSession.getSessionId() + } returns Single.just("5598") + + every { + stubRdsSwarm.fetchLoras(any()) + } returns Single.just(mockStableDiffusionLoras) + + every { + stubLds.insertLoras(any()) + } returns Completable.complete() + + every { + stubLds.getLoras() + } returns Single.just(mockStableDiffusionLoras) + + repository + .fetchAndGetLoras() + .test() + .assertNoErrors() + .assertValue(mockStableDiffusionLoras) + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch and get loras, source is AUTOMATIC1111, remote fails, local returns data, expected valid domain models list value`() { + every { + stubPreferenceManager::source.get() + } returns ServerSource.AUTOMATIC1111 + + every { + stubRdsA1111.fetchLoras() + } returns Single.error(stubException) + + every { + stubLds.insertLoras(any()) + } returns Completable.complete() + + every { + stubLds.getLoras() + } returns Single.just(mockStableDiffusionLoras) + + repository + .fetchAndGetLoras() + .test() + .assertNoErrors() + .assertValue(mockStableDiffusionLoras) + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch and get loras, source is SWARM_UI, remote fails, local returns data, expected valid domain models list value`() { + every { + stubPreferenceManager::source.get() + } returns ServerSource.SWARM_UI + + every { + stubSwarmSession.getSessionId(any()) + } returns Single.just("5598") + + every { + stubSwarmSession.getSessionId() + } returns Single.just("5598") + + every { + stubRdsSwarm.fetchLoras(any()) } returns Single.error(stubException) every { - stubLocalDataSource.insertLoras(any()) + stubLds.insertLoras(any()) } returns Completable.complete() every { - stubLocalDataSource.getLoras() + stubLds.getLoras() } returns Single.just(mockStableDiffusionLoras) repository @@ -165,13 +362,48 @@ class LorasRepositoryImplTest { } @Test - fun `given attempt to fetch and get loras, remote fails, local fails, expected valid error value`() { + fun `given attempt to fetch and get loras, source is AUTOMATIC1111, remote fails, local fails, expected valid error value`() { + every { + stubPreferenceManager::source.get() + } returns ServerSource.AUTOMATIC1111 + + every { + stubRdsA1111.fetchLoras() + } returns Single.error(stubException) + + every { + stubLds.getLoras() + } returns Single.error(stubException) + + repository + .fetchAndGetLoras() + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to fetch and get loras, source is SWARM_UI, remote fails, local fails, expected valid error value`() { + every { + stubPreferenceManager::source.get() + } returns ServerSource.SWARM_UI + + every { + stubSwarmSession.getSessionId(any()) + } returns Single.just("5598") + + every { + stubSwarmSession.getSessionId() + } returns Single.just("5598") + every { - stubRemoteDataSource.fetchLoras() + stubRdsSwarm.fetchLoras(any()) } returns Single.error(stubException) every { - stubLocalDataSource.getLoras() + stubLds.getLoras() } returns Single.error(stubException) repository diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/repository/SwarmUiGenerationRepositoryImplTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/repository/SwarmUiGenerationRepositoryImplTest.kt new file mode 100644 index 00000000..53e09a05 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/repository/SwarmUiGenerationRepositoryImplTest.kt @@ -0,0 +1,211 @@ +package com.shifthackz.aisdv1.data.repository + +import com.shifthackz.aisdv1.core.imageprocessing.Base64ToBitmapConverter +import com.shifthackz.aisdv1.data.mocks.mockAiGenerationResult +import com.shifthackz.aisdv1.data.mocks.mockImageToImagePayload +import com.shifthackz.aisdv1.data.mocks.mockTextToImagePayload +import com.shifthackz.aisdv1.domain.datasource.GenerationResultDataSource +import com.shifthackz.aisdv1.domain.datasource.SwarmUiGenerationDataSource +import com.shifthackz.aisdv1.domain.datasource.SwarmUiSessionDataSource +import com.shifthackz.aisdv1.domain.gateway.MediaStoreGateway +import com.shifthackz.aisdv1.domain.preference.PreferenceManager +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Single +import org.junit.Before +import org.junit.Test + +class SwarmUiGenerationRepositoryImplTest { + + private val stubException = Throwable("Something went wrong.") + private val stubMediaStoreGateway = mockk() + private val stubBase64ToBitmapConverter = mockk() + private val stubLocalDataSource = mockk() + private val stubRemoteDataSource = mockk() + private val stubSession = mockk() + private val stubPreferenceManager = mockk() + + private val repository = SwarmUiGenerationRepositoryImpl( + mediaStoreGateway = stubMediaStoreGateway, + base64ToBitmapConverter = stubBase64ToBitmapConverter, + localDataSource = stubLocalDataSource, + remoteDataSource = stubRemoteDataSource, + session = stubSession, + preferenceManager = stubPreferenceManager, + ) + + @Before + fun initialize() { + every { + stubPreferenceManager.autoSaveAiResults + } returns false + + every { + stubSession.handleSessionError(any>()) + } returnsArgument 0 + } + + @Test + fun `given attempt to check api availability, remote completes, expected complete value`() { + every { + stubSession.getSessionId() + } returns Single.just("5598") + + repository + .checkApiAvailability() + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `given attempt to check api availability, remote throws exception, expected error value`() { + every { + stubSession.getSessionId() + } returns Single.error(stubException) + + repository + .checkApiAvailability() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to check api availability by url, remote completes, expected complete value`() { + every { + stubSession.getSessionId(any()) + } returns Single.just("5598") + + repository + .checkApiAvailability("https://5598.is.my.favourite.com") + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `given attempt to check api availability by url, remote throws exception, expected error value`() { + every { + stubSession.getSessionId(any()) + } returns Single.error(stubException) + + repository + .checkApiAvailability("https://5598.is.my.favourite.com") + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to generate from text, remote returns result, expected valid domain model value`() { + every { + stubPreferenceManager::swarmUiModel.get() + } returns "5598" + + every { + stubSession.getSessionId(any()) + } returns Single.just("5598") + + every { + stubSession.getSessionId() + } returns Single.just("5598") + + every { + stubRemoteDataSource.textToImage(any(), any(), any()) + } returns Single.just(mockAiGenerationResult) + + repository + .generateFromText(mockTextToImagePayload) + .test() + .assertNoErrors() + .assertValue(mockAiGenerationResult) + .await() + .assertComplete() + } + + @Test + fun `given attempt to generate from text, remote throws exception, expected error value`() { + every { + stubPreferenceManager::swarmUiModel.get() + } returns "5598" + + every { + stubSession.getSessionId(any()) + } returns Single.just("5598") + + every { + stubSession.getSessionId() + } returns Single.just("5598") + + every { + stubRemoteDataSource.textToImage(any(), any(), any()) + } returns Single.error(stubException) + + repository + .generateFromText(mockTextToImagePayload) + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to generate from image, remote returns result, expected valid domain model value`() { + every { + stubPreferenceManager::swarmUiModel.get() + } returns "5598" + + every { + stubSession.getSessionId(any()) + } returns Single.just("5598") + + every { + stubSession.getSessionId() + } returns Single.just("5598") + + every { + stubRemoteDataSource.imageToImage(any(), any(), any()) + } returns Single.just(mockAiGenerationResult) + + repository + .generateFromImage(mockImageToImagePayload) + .test() + .assertNoErrors() + .assertValue(mockAiGenerationResult) + .await() + .assertComplete() + } + + @Test + fun `given attempt to generate from image, remote throws exception, expected error value`() { + every { + stubPreferenceManager::swarmUiModel.get() + } returns "5598" + + every { + stubSession.getSessionId(any()) + } returns Single.just("5598") + + every { + stubSession.getSessionId() + } returns Single.just("5598") + + every { + stubRemoteDataSource.imageToImage(any(), any(), any()) + } returns Single.error(stubException) + + repository + .generateFromImage(mockImageToImagePayload) + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/repository/SwarmUiModelsRepositoryImplTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/repository/SwarmUiModelsRepositoryImplTest.kt new file mode 100644 index 00000000..ca6cb6cf --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/repository/SwarmUiModelsRepositoryImplTest.kt @@ -0,0 +1,200 @@ +package com.shifthackz.aisdv1.data.repository + +import com.shifthackz.aisdv1.data.mocks.mockSwarmUiModels +import com.shifthackz.aisdv1.domain.datasource.SwarmUiModelsDataSource +import com.shifthackz.aisdv1.domain.datasource.SwarmUiSessionDataSource +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single +import org.junit.Before +import org.junit.Test + +class SwarmUiModelsRepositoryImplTest { + + private val stubException = Throwable("Something went wrong.") + private val stubSession = mockk() + private val stubRds = mockk() + private val stubLds = mockk() + + private val repository = SwarmUiModelsRepositoryImpl(stubSession, stubRds, stubLds) + + @Before + fun initialize() { + every { + stubSession.getSessionId(any()) + } returns Single.just("5598") + + every { + stubSession.getSessionId() + } returns Single.just("5598") + + every { + stubSession.handleSessionError(any>()) + } returnsArgument 0 + } + + @Test + fun `given attempt to fetch models, remote returns data, local insert success, expected complete value`() { + every { + stubRds.fetchSwarmModels(any()) + } returns Single.just(mockSwarmUiModels) + + every { + stubLds.insertModels(any()) + } returns Completable.complete() + + repository + .fetchModels() + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch models, remote throws exception, local insert success, expected error value`() { + every { + stubRds.fetchSwarmModels(any()) + } returns Single.error(stubException) + + every { + stubLds.insertModels(any()) + } returns Completable.complete() + + repository + .fetchModels() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to fetch models, remote returns data, local insert fails, expected error value`() { + every { + stubRds.fetchSwarmModels(any()) + } returns Single.just(mockSwarmUiModels) + + every { + stubLds.insertModels(any()) + } returns Completable.error(stubException) + + repository + .fetchModels() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to get models, local data source returns list, expected valid domain models list value`() { + every { + stubLds.getModels() + } returns Single.just(mockSwarmUiModels) + + repository + .getModels() + .test() + .assertNoErrors() + .assertValue(mockSwarmUiModels) + .await() + .assertComplete() + } + + @Test + fun `given attempt to get models, local data source returns empty list, expected empty domain models list value`() { + every { + stubLds.getModels() + } returns Single.just(emptyList()) + + repository + .getModels() + .test() + .assertNoErrors() + .assertValue(emptyList()) + .await() + .assertComplete() + } + + @Test + fun `given attempt to get models, local data source throws exception, expected error value`() { + every { + stubLds.getModels() + } returns Single.error(stubException) + + repository + .getModels() + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to fetch and get models, remote returns data, local returns data, expected valid domain models list value`() { + every { + stubRds.fetchSwarmModels(any()) + } returns Single.just(mockSwarmUiModels) + + every { + stubLds.insertModels(any()) + } returns Completable.complete() + + every { + stubLds.getModels() + } returns Single.just(mockSwarmUiModels) + + repository + .fetchAndGetModels() + .test() + .assertNoErrors() + .assertValue(mockSwarmUiModels) + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch and get models, remote fails, local returns data, expected valid domain models list value`() { + every { + stubRds.fetchSwarmModels(any()) + } returns Single.error(stubException) + + every { + stubLds.insertModels(any()) + } returns Completable.complete() + + every { + stubLds.getModels() + } returns Single.just(mockSwarmUiModels) + + repository + .fetchAndGetModels() + .test() + .assertNoErrors() + .assertValue(mockSwarmUiModels) + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch and get models, remote fails, local fails, expected valid error value`() { + every { + stubRds.fetchSwarmModels(any()) + } returns Single.error(stubException) + + every { + stubLds.getModels() + } returns Single.error(stubException) + + repository + .fetchAndGetModels() + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } +} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/SwarmUiSessionDataSource.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/SwarmUiSessionDataSource.kt index 0324599a..af6807db 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/SwarmUiSessionDataSource.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/SwarmUiSessionDataSource.kt @@ -4,4 +4,6 @@ import io.reactivex.rxjava3.core.Single interface SwarmUiSessionDataSource { fun getSessionId(connectUrl: String? = null): Single + fun forceRenew(connectUrl: String? = null): Single + fun handleSessionError(chain: Single): Single } diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/ConfigurationMocks.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/ConfigurationMocks.kt index 84e9f9da..114e97b4 100644 --- a/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/ConfigurationMocks.kt +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/ConfigurationMocks.kt @@ -4,6 +4,8 @@ import com.shifthackz.aisdv1.domain.entity.Configuration val mockConfiguration = Configuration( serverUrl = "http://5598.is.my.favorite.com", + swarmUiUrl = "http://5598.is.my.favorite.com", + swarmUiModel = "5598", hordeApiKey = "5598", openAiApiKey = "5598", huggingFaceApiKey = "5598", diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/LoraMocks.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/LoraMocks.kt new file mode 100644 index 00000000..7e457a43 --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/LoraMocks.kt @@ -0,0 +1,16 @@ +package com.shifthackz.aisdv1.domain.mocks + +import com.shifthackz.aisdv1.domain.entity.LoRA + +val mockLoRAs = listOf( + LoRA( + name = "5598", + alias = "5598", + path = "/unknown", + ), + LoRA( + name = "151297", + alias = "151297", + path = "/unknown", + ), +) diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/ImageToImageUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/ImageToImageUseCaseImplTest.kt index 2a1544b3..9f2fac3b 100644 --- a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/ImageToImageUseCaseImplTest.kt +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/ImageToImageUseCaseImplTest.kt @@ -11,6 +11,7 @@ import com.shifthackz.aisdv1.domain.repository.HordeGenerationRepository import com.shifthackz.aisdv1.domain.repository.HuggingFaceGenerationRepository import com.shifthackz.aisdv1.domain.repository.StabilityAiGenerationRepository import com.shifthackz.aisdv1.domain.repository.StableDiffusionGenerationRepository +import com.shifthackz.aisdv1.domain.repository.SwarmUiGenerationRepository import io.reactivex.rxjava3.core.Single import org.junit.Test @@ -18,6 +19,7 @@ class ImageToImageUseCaseImplTest { private val stubException = Throwable("Unable to generate image.") private val stubStableDiffusionGenerationRepository = mock() + private val stubSwarmUiGenerationRepository = mock() private val stubHordeGenerationRepository = mock() private val stubHuggingFaceGenerationRepository = mock() private val stubStabilityAiGenerationRepository = mock() @@ -25,6 +27,7 @@ class ImageToImageUseCaseImplTest { private val useCase = ImageToImageUseCaseImpl( stableDiffusionGenerationRepository = stubStableDiffusionGenerationRepository, + swarmUiGenerationRepository = stubSwarmUiGenerationRepository, hordeGenerationRepository = stubHordeGenerationRepository, huggingFaceGenerationRepository = stubHuggingFaceGenerationRepository, stabilityAiGenerationRepository = stubStabilityAiGenerationRepository, diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImplTest.kt index 1283e409..b6d3bced 100644 --- a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImplTest.kt +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImplTest.kt @@ -13,6 +13,7 @@ import com.shifthackz.aisdv1.domain.repository.LocalDiffusionGenerationRepositor import com.shifthackz.aisdv1.domain.repository.OpenAiGenerationRepository import com.shifthackz.aisdv1.domain.repository.StabilityAiGenerationRepository import com.shifthackz.aisdv1.domain.repository.StableDiffusionGenerationRepository +import com.shifthackz.aisdv1.domain.repository.SwarmUiGenerationRepository import io.reactivex.rxjava3.core.Single import org.junit.Test @@ -24,6 +25,7 @@ class TextToImageUseCaseImplTest { private val stubHuggingFaceGenerationRepository = mock() private val stubOpenAiGenerationRepository = mock() private val stubStabilityAiGenerationRepository = mock() + private val stubSwarmUiGenerationRepository = mock() private val stubLocalDiffusionGenerationRepository = mock() private val stubPreferenceManager = mock() @@ -34,6 +36,7 @@ class TextToImageUseCaseImplTest { openAiGenerationRepository = stubOpenAiGenerationRepository, stabilityAiGenerationRepository = stubStabilityAiGenerationRepository, localDiffusionGenerationRepository = stubLocalDiffusionGenerationRepository, + swarmUiGenerationRepository = stubSwarmUiGenerationRepository, preferenceManager = stubPreferenceManager, ) diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/sdlora/FetchAndGetLorasUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/sdlora/FetchAndGetLorasUseCaseImplTest.kt new file mode 100644 index 00000000..643bf9f2 --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/sdlora/FetchAndGetLorasUseCaseImplTest.kt @@ -0,0 +1,55 @@ +package com.shifthackz.aisdv1.domain.usecase.sdlora + +import com.nhaarman.mockitokotlin2.mock +import com.nhaarman.mockitokotlin2.whenever +import com.shifthackz.aisdv1.domain.mocks.mockLoRAs +import com.shifthackz.aisdv1.domain.repository.LorasRepository +import io.reactivex.rxjava3.core.Single +import org.junit.Test + +class FetchAndGetLorasUseCaseImplTest { + + private val stubRepository = mock() + + private val useCase = FetchAndGetLorasUseCaseImpl(stubRepository) + + @Test + fun `given repository provided list of LoRAs, expected valid list value`() { + whenever(stubRepository.fetchAndGetLoras()) + .thenReturn(Single.just(mockLoRAs)) + + useCase() + .test() + .assertNoErrors() + .assertValue(mockLoRAs) + .await() + .assertComplete() + } + + @Test + fun `given repository provided empty list of LoRAs, expected empty list value`() { + whenever(stubRepository.fetchAndGetLoras()) + .thenReturn(Single.just(emptyList())) + + useCase() + .test() + .assertNoErrors() + .assertValue(emptyList()) + .await() + .assertComplete() + } + + @Test + fun `given repository thrown exception, expected error value`() { + val stubException = Throwable("Unknown error occurred.") + + whenever(stubRepository.fetchAndGetLoras()) + .thenReturn(Single.error(stubException)) + + useCase() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } +} diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImplTest.kt index b49d7b08..7007cdb9 100644 --- a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImplTest.kt +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImplTest.kt @@ -28,6 +28,14 @@ class GetConfigurationUseCaseImplTest { stubPreferenceManager::automatic1111serverUrl.get() } returns mockConfiguration.serverUrl + every { + stubPreferenceManager::swarmUiServerUrl.get() + } returns mockConfiguration.swarmUiUrl + + every { + stubPreferenceManager::swarmUiModel.get() + } returns mockConfiguration.swarmUiModel + every { stubPreferenceManager::demoMode.get() } returns mockConfiguration.demoMode diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImplTest.kt index 0bf8b0cf..ef1bc037 100644 --- a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImplTest.kt +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImplTest.kt @@ -31,6 +31,14 @@ class SetServerConfigurationUseCaseImplTest { stubPreferenceManager::automatic1111serverUrl.set(any()) } returns Unit + every { + stubPreferenceManager::swarmUiModel.set(any()) + } returns Unit + + every { + stubPreferenceManager::swarmUiServerUrl.set(any()) + } returns Unit + every { stubPreferenceManager::demoMode.set(any()) } returns Unit diff --git a/network/src/main/java/com/shifthackz/aisdv1/network/api/swarmui/SwarmUiApiImpl.kt b/network/src/main/java/com/shifthackz/aisdv1/network/api/swarmui/SwarmUiApiImpl.kt index e17a46a7..84a94a64 100644 --- a/network/src/main/java/com/shifthackz/aisdv1/network/api/swarmui/SwarmUiApiImpl.kt +++ b/network/src/main/java/com/shifthackz/aisdv1/network/api/swarmui/SwarmUiApiImpl.kt @@ -2,12 +2,14 @@ package com.shifthackz.aisdv1.network.api.swarmui import android.graphics.Bitmap import android.graphics.BitmapFactory +import com.shifthackz.aisdv1.network.exception.BadSessionException import com.shifthackz.aisdv1.network.request.SwarmUiGenerationRequest import com.shifthackz.aisdv1.network.request.SwarmUiModelsRequest import com.shifthackz.aisdv1.network.response.SwarmUiGenerationResponse import com.shifthackz.aisdv1.network.response.SwarmUiModelsResponse import com.shifthackz.aisdv1.network.response.SwarmUiSessionResponse import io.reactivex.rxjava3.core.Single +import retrofit2.HttpException internal class SwarmUiApiImpl( private val rawApi: SwarmUiApi.RawApi, @@ -15,19 +17,25 @@ internal class SwarmUiApiImpl( override fun getNewSession(url: String): Single = rawApi .getNewSession(url, emptyMap()) + .mapError() override fun generate( url: String, request: SwarmUiGenerationRequest, - ): Single = rawApi.generate(url, request) + ): Single = rawApi + .generate(url, request) + .mapError() override fun fetchModels( url: String, request: SwarmUiModelsRequest - ): Single = rawApi.fetchModels(url, request) + ): Single = rawApi + .fetchModels(url, request) + .mapError() override fun downloadImage(url: String): Single = rawApi .download(url) + .mapError() .flatMap { response -> response.body() ?.bytes() @@ -35,4 +43,12 @@ internal class SwarmUiApiImpl( ?.let { Single.just(it) } ?: Single.error(Throwable("Body is null")) } + + private fun Single.mapError(): Single = this.onErrorResumeNext { t -> + if (t is HttpException && t.code() == 401) { + Single.error(BadSessionException()) + } else { + Single.error(t) + } + } } diff --git a/network/src/main/java/com/shifthackz/aisdv1/network/exception/BadSessionException.kt b/network/src/main/java/com/shifthackz/aisdv1/network/exception/BadSessionException.kt new file mode 100644 index 00000000..c0589290 --- /dev/null +++ b/network/src/main/java/com/shifthackz/aisdv1/network/exception/BadSessionException.kt @@ -0,0 +1,3 @@ +package com.shifthackz.aisdv1.network.exception + +class BadSessionException : Throwable() diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsScreen.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsScreen.kt index 06fb45ed..ec4b47ad 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsScreen.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsScreen.kt @@ -107,7 +107,6 @@ fun SettingsScreen() { } } - @Composable private fun ScreenContent( modifier: Modifier = Modifier, diff --git a/presentation/src/main/res/values-tr/strings.xml b/presentation/src/main/res/values-tr/strings.xml index 86cf57bf..429c7598 100644 --- a/presentation/src/main/res/values-tr/strings.xml +++ b/presentation/src/main/res/values-tr/strings.xml @@ -98,8 +98,8 @@ Gradio kütüphanesi kullanılarak uygulanan Stable Diffusion için bir web arayüzü. Bazı sunucu örnekleri:\n• http://192.168.0.2:7860\n• http://alanadiniz.com:7860\n• https://alanadiniz.com Bu mod, Stable Diffusion WebUI sunucunuz olmasa bile uygulama davranışını test etmenize olanak tanır.\n\nDemo modunda uygulama kullanıcı istemini görmezden gelir, AI sunucusunu kullanmaz ve bazı sahte görüntüler döndürür. - Bağlanmadan önce şunlardan emin olun:\n• AUTOMATIC1111 WebUI\'yi --api --listen bayraklarıyla çalıştırıyorsunuz\n• Güvenlik duvarınız 7860 portunu engellemiyor\n• Telefon, bilgisayarınızla aynı WiFi'da. - Bağlanmadan önce şunlardan emin olun:\n• Güvenlik duvarınız 7860 portunu engellemiyor\n• Telefon, bilgisayarınızla aynı WiFi'da. + Bağlanmadan önce şunlardan emin olun:\n• AUTOMATIC1111 WebUI\'yi --api --listen bayraklarıyla çalıştırıyorsunuz\n• Güvenlik duvarınız 7860 portunu engellemiyor\n• Telefon, bilgisayarınızla aynı WiFi\'da. + Bağlanmadan önce şunlardan emin olun:\n• Güvenlik duvarınız 7860 portunu engellemiyor\n• Telefon, bilgisayarınızla aynı WiFi\'da. Horde AI bulutuna bağlanın Horde AI, Görüntü oluşturma çalışanları ve metin oluşturma çalışanlarından oluşan kitle kaynaklı dağıtılmış bir kümedir. diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/SwarmUiModelMocks.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/SwarmUiModelMocks.kt new file mode 100644 index 00000000..167c49a1 --- /dev/null +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/SwarmUiModelMocks.kt @@ -0,0 +1,11 @@ +package com.shifthackz.aisdv1.presentation.mocks + +import com.shifthackz.aisdv1.domain.entity.SwarmUiModel + +val mockSwarmUiModels = listOf( + SwarmUiModel( + name = "5598", + title = "5598", + author = "ShiftHackZ", + ) +) diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModelTest.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModelTest.kt index e6bb063c..7be36cd4 100644 --- a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModelTest.kt +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModelTest.kt @@ -26,7 +26,6 @@ import io.mockk.verify import io.reactivex.rxjava3.core.Completable import io.reactivex.rxjava3.core.Observable import io.reactivex.rxjava3.core.Single -import kotlinx.coroutines.flow.firstOrNull import kotlinx.coroutines.test.runTest import org.junit.Assert import org.junit.Before @@ -36,8 +35,7 @@ class ServerSetupViewModelTest : CoreViewModelTest() { private val stubGetConfigurationUseCase = mockk() private val stubGetLocalAiModelsUseCase = mockk() - private val stubFetchAndGetHuggingFaceModelsUseCase = - mockk() + private val stubFetchAndGetHuggingFaceModelsUseCase = mockk() private val stubUrlValidator = mockk() private val stubCommonStringValidator = mockk() private val stubSetupConnectionInterActor = mockk() @@ -342,16 +340,6 @@ class ServerSetupViewModelTest : CoreViewModelTest() { } } - @Test - fun `given received LaunchManageStoragePermission intent, expected LaunchManageStoragePermission effect delivered to effect collector`() { - viewModel.processIntent(ServerSetupIntent.LaunchManageStoragePermission) - runTest { - val expected = ServerSetupEffect.LaunchManageStoragePermission - val actual = viewModel.effect.firstOrNull() - Assert.assertEquals(expected, actual) - } - } - @Test fun `given received NavigateBack intent, expected router navigateBack() method called`() { every { diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModelTest.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModelTest.kt index cf2d83cb..b85b43fc 100644 --- a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModelTest.kt +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModelTest.kt @@ -11,11 +11,13 @@ import com.shifthackz.aisdv1.domain.usecase.sdmodel.GetStableDiffusionModelsUseC import com.shifthackz.aisdv1.domain.usecase.sdmodel.SelectStableDiffusionModelUseCase import com.shifthackz.aisdv1.domain.usecase.settings.GetConfigurationUseCase import com.shifthackz.aisdv1.domain.usecase.stabilityai.FetchAndGetStabilityAiEnginesUseCase +import com.shifthackz.aisdv1.domain.usecase.swarmmodel.FetchAndGetSwarmUiModelsUseCase import com.shifthackz.aisdv1.presentation.core.CoreViewModelTest import com.shifthackz.aisdv1.presentation.mocks.mockHuggingFaceModels import com.shifthackz.aisdv1.presentation.mocks.mockLocalAiModels import com.shifthackz.aisdv1.presentation.mocks.mockStabilityAiEngines import com.shifthackz.aisdv1.presentation.mocks.mockStableDiffusionModels +import com.shifthackz.aisdv1.presentation.mocks.mockSwarmUiModels import com.shifthackz.aisdv1.presentation.stub.stubSchedulersProvider import io.mockk.every import io.mockk.mockk @@ -42,10 +44,9 @@ class EngineSelectionViewModelTest : CoreViewModelTest private val stubSelectStableDiffusionModelUseCase = mockk() private val stubGetStableDiffusionModelsUseCase = mockk() private val stubObserveLocalAiModelsUseCase = mockk() - private val stubFetchAndGetStabilityAiEnginesUseCase = - mockk() - private val stubFetchAndGetHuggingFaceModelsUseCase = - mockk() + private val stubFetchAndGetStabilityAiEnginesUseCase = mockk() + private val stubFetchAndGetHuggingFaceModelsUseCase = mockk() + private val stubFetchAndGetSwarmUiModelsUseCase = mockk() override fun initializeViewModel() = EngineSelectionViewModel( preferenceManager = stubPreferenceManager, @@ -56,6 +57,7 @@ class EngineSelectionViewModelTest : CoreViewModelTest observeLocalAiModelsUseCase = stubObserveLocalAiModelsUseCase, fetchAndGetStabilityAiEnginesUseCase = stubFetchAndGetStabilityAiEnginesUseCase, getHuggingFaceModelsUseCase = stubFetchAndGetHuggingFaceModelsUseCase, + fetchAndGetSwarmUiModelsUseCase = stubFetchAndGetSwarmUiModelsUseCase, ) @Before @@ -106,6 +108,8 @@ class EngineSelectionViewModelTest : CoreViewModelTest selectedStEngine = "5598", localAiModels = listOf(LocalAiModel.CUSTOM), selectedLocalAiModelId = "CUSTOM", + swarmModels = listOf("5598"), + selectedSwarmModel = "5598", ) val actual = viewModel.state.value Assert.assertEquals(expected, actual) @@ -176,6 +180,21 @@ class EngineSelectionViewModelTest : CoreViewModelTest } } + @Test + fun `given received EngineSelectionIntent, source is SWARM_UI, expected swarmModel changed in preference`() { + mockInitialData(DataTestCase.Mock, ServerSource.SWARM_UI) + + every { + stubPreferenceManager::swarmUiModel.set(any()) + } returns Unit + + viewModel.processIntent(EngineSelectionIntent("151297")) + + verify { + stubPreferenceManager::swarmUiModel.set("151297") + } + } + @Test fun `given received EngineSelectionIntent, source is HUGGING_FACE, expected huggingFaceModel changed in preference`() { mockInitialData(DataTestCase.Mock, ServerSource.HUGGING_FACE) @@ -240,6 +259,7 @@ class EngineSelectionViewModelTest : CoreViewModelTest Configuration( huggingFaceModel = "prompthero/openjourney-v4", stabilityAiEngineId = "5598", + swarmUiModel = "5598", localModelId = "CUSTOM", source = source, ), @@ -273,6 +293,14 @@ class EngineSelectionViewModelTest : CoreViewModelTest DataTestCase.Exception -> Single.error(stubException) } + every { + stubFetchAndGetSwarmUiModelsUseCase() + } returns when (testCase) { + DataTestCase.Mock -> Single.just(mockSwarmUiModels) + DataTestCase.Empty -> Single.just(emptyList()) + DataTestCase.Exception -> Single.error(stubException) + } + stubLocalAiModels.onNext( when (testCase) { DataTestCase.Mock -> Result.success(mockLocalAiModels) From 024e41724b9219e9c6a874d5b4580d17f5249dea Mon Sep 17 00:00:00 2001 From: ShiftHackZ Date: Mon, 5 Aug 2024 14:19:14 +0300 Subject: [PATCH 09/11] Update empty UI states --- .../network/interceptor/LoggingInterceptor.kt | 3 +-- .../modal/embedding/EmbeddingScreen.kt | 17 ++++++++++--- .../modal/embedding/EmbeddingState.kt | 2 ++ .../modal/embedding/EmbeddingViewModel.kt | 18 ++++++++++++- .../presentation/modal/extras/ExtrasScreen.kt | 25 +++++++++++++++---- .../presentation/modal/extras/ExtrasState.kt | 2 ++ .../modal/extras/ExtrasViewModel.kt | 19 +++++++++++++- .../navigation/graph/MainNavGraph.kt | 2 +- .../navigation/router/main/MainRouterImpl.kt | 2 +- .../screen/setup/ServerSetupState.kt | 9 +++---- .../components/ConfigurationModeButton.kt | 11 ++------ .../widget/source/ServerSourceLabel.kt | 19 ++++++++++++++ .../src/main/res/values-ru/strings.xml | 12 ++++----- .../src/main/res/values-tr/strings.xml | 12 ++++----- .../src/main/res/values-uk/strings.xml | 12 ++++----- .../src/main/res/values-zh/strings.xml | 12 ++++----- presentation/src/main/res/values/strings.xml | 8 +++--- .../modal/embedding/EmbeddingViewModelTest.kt | 14 +++++++++++ .../modal/extras/ExtrasViewModelTest.kt | 8 ++++++ .../router/main/MainRouterImplTest.kt | 4 +-- 20 files changed, 153 insertions(+), 58 deletions(-) create mode 100644 presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/source/ServerSourceLabel.kt diff --git a/network/src/main/java/com/shifthackz/aisdv1/network/interceptor/LoggingInterceptor.kt b/network/src/main/java/com/shifthackz/aisdv1/network/interceptor/LoggingInterceptor.kt index b592f9b5..e138632e 100644 --- a/network/src/main/java/com/shifthackz/aisdv1/network/interceptor/LoggingInterceptor.kt +++ b/network/src/main/java/com/shifthackz/aisdv1/network/interceptor/LoggingInterceptor.kt @@ -8,8 +8,7 @@ internal class LoggingInterceptor { fun get() = HttpLoggingInterceptor { message -> debugLog(HTTP_TAG, message) }.apply { -// level = HttpLoggingInterceptor.Level.HEADERS - level = HttpLoggingInterceptor.Level.BODY + level = HttpLoggingInterceptor.Level.HEADERS } companion object { diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/embedding/EmbeddingScreen.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/embedding/EmbeddingScreen.kt index 16b72387..57746dcc 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/embedding/EmbeddingScreen.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/embedding/EmbeddingScreen.kt @@ -47,10 +47,12 @@ import androidx.compose.ui.window.Dialog import androidx.compose.ui.window.DialogProperties import com.shifthackz.aisdv1.core.extensions.shimmer import com.shifthackz.aisdv1.core.ui.MviComponent +import com.shifthackz.aisdv1.domain.entity.ServerSource import com.shifthackz.aisdv1.presentation.R import com.shifthackz.aisdv1.presentation.modal.extras.ExtrasEffect import com.shifthackz.aisdv1.presentation.model.ErrorState import com.shifthackz.aisdv1.presentation.widget.error.ErrorComposable +import com.shifthackz.aisdv1.presentation.widget.source.getName import com.shifthackz.aisdv1.presentation.widget.toolbar.ModalDialogToolbar import org.koin.androidx.compose.koinViewModel @@ -209,7 +211,7 @@ private fun ScreenContent( } else { if (state.embeddings.isEmpty()) { item(key = "empty_state") { - EmbeddingEmptyState() + EmbeddingEmptyState(state.source) } } else { items( @@ -234,7 +236,7 @@ private fun ScreenContent( } @Composable -private fun EmbeddingEmptyState() { +private fun EmbeddingEmptyState(source: ServerSource) { Column( verticalArrangement = Arrangement.Center, ) { @@ -243,12 +245,21 @@ private fun EmbeddingEmptyState() { text = stringResource(id = R.string.extras_empty_title), fontSize = 20.sp, ) + val path = when (source) { + ServerSource.AUTOMATIC1111 -> "./embeddings" + ServerSource.SWARM_UI -> "./Models/Embeddings" + else -> "" + } Text( modifier = Modifier .padding(top = 16.dp) .padding(horizontal = 16.dp) .align(Alignment.CenterHorizontally), - text = stringResource(id = R.string.extras_empty_sub_title_embedding), + text = stringResource( + id = R.string.extras_empty_sub_title_embedding, + source.getName(), + path, + ), textAlign = TextAlign.Center, ) } diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/embedding/EmbeddingState.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/embedding/EmbeddingState.kt index 93338f60..9898c607 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/embedding/EmbeddingState.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/embedding/EmbeddingState.kt @@ -1,12 +1,14 @@ package com.shifthackz.aisdv1.presentation.modal.embedding import androidx.compose.runtime.Immutable +import com.shifthackz.aisdv1.domain.entity.ServerSource import com.shifthackz.aisdv1.presentation.model.ErrorState import com.shifthackz.android.core.mvi.MviState @Immutable data class EmbeddingState( val loading: Boolean = true, + val source: ServerSource = ServerSource.AUTOMATIC1111, val error: ErrorState = ErrorState.None, val prompt: String = "", val negativePrompt: String = "", diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/embedding/EmbeddingViewModel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/embedding/EmbeddingViewModel.kt index 4bc3eec7..70f75d31 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/embedding/EmbeddingViewModel.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/embedding/EmbeddingViewModel.kt @@ -4,6 +4,7 @@ import com.shifthackz.aisdv1.core.common.log.errorLog import com.shifthackz.aisdv1.core.common.schedulers.SchedulersProvider import com.shifthackz.aisdv1.core.common.schedulers.subscribeOnMainThread import com.shifthackz.aisdv1.core.viewmodel.MviRxViewModel +import com.shifthackz.aisdv1.domain.preference.PreferenceManager import com.shifthackz.aisdv1.domain.usecase.sdembedding.FetchAndGetEmbeddingsUseCase import com.shifthackz.aisdv1.presentation.modal.extras.ExtrasEffect import com.shifthackz.aisdv1.presentation.model.ErrorState @@ -12,11 +13,18 @@ import io.reactivex.rxjava3.kotlin.subscribeBy class EmbeddingViewModel( private val fetchAndGetEmbeddingsUseCase: FetchAndGetEmbeddingsUseCase, + private val preferenceManager: PreferenceManager, private val schedulersProvider: SchedulersProvider, ) : MviRxViewModel() { override val initialState = EmbeddingState() + init { + updateState { + it.copy(source = preferenceManager.source) + } + } + override fun processIntent(intent: EmbeddingIntent) { when (intent) { EmbeddingIntent.ApplyNewPrompts -> emitEffect( @@ -52,7 +60,14 @@ class EmbeddingViewModel( } fun updateData(prompt: String, negativePrompt: String) = !fetchAndGetEmbeddingsUseCase() - .doOnSubscribe { updateState { it.copy(loading = true) } } + .doOnSubscribe { + updateState { state -> + state.copy( + loading = true, + source = preferenceManager.source, + ) + } + } .subscribeOnMainThread(schedulersProvider) .subscribeBy( onError = { t -> @@ -63,6 +78,7 @@ class EmbeddingViewModel( updateState { state -> state.copy( loading = false, + source = preferenceManager.source, error = ErrorState.None, prompt = prompt, negativePrompt = negativePrompt, diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/extras/ExtrasScreen.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/extras/ExtrasScreen.kt index 2d669744..efee8a6d 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/extras/ExtrasScreen.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/extras/ExtrasScreen.kt @@ -44,10 +44,12 @@ import androidx.compose.ui.window.Dialog import androidx.compose.ui.window.DialogProperties import com.shifthackz.aisdv1.core.extensions.shimmer import com.shifthackz.aisdv1.core.ui.MviComponent +import com.shifthackz.aisdv1.domain.entity.ServerSource import com.shifthackz.aisdv1.presentation.R import com.shifthackz.aisdv1.presentation.model.ErrorState import com.shifthackz.aisdv1.presentation.model.ExtraType import com.shifthackz.aisdv1.presentation.widget.error.ErrorComposable +import com.shifthackz.aisdv1.presentation.widget.source.getName import com.shifthackz.aisdv1.presentation.widget.toolbar.ModalDialogToolbar import org.koin.androidx.compose.koinViewModel @@ -164,7 +166,7 @@ private fun ScreenContent( } else { if (state.loras.isEmpty()) { item(key = "empty_state") { - ExtrasEmptyState(type = state.type) + ExtrasEmptyState(state.type, state.source) } } else { items( @@ -188,7 +190,7 @@ private fun ScreenContent( } @Composable -private fun ExtrasEmptyState(type: ExtraType) { +private fun ExtrasEmptyState(type: ExtraType, source: ServerSource) { Column( verticalArrangement = Arrangement.Center, ) { @@ -197,17 +199,30 @@ private fun ExtrasEmptyState(type: ExtraType) { text = stringResource(id = R.string.extras_empty_title), fontSize = 20.sp, ) + val path = when (type) { + ExtraType.Lora -> when (source) { + ServerSource.AUTOMATIC1111 -> "../models/Lora" + ServerSource.SWARM_UI -> "../Models/Lora" + else -> "" + } + ExtraType.HyperNet -> when (source) { + ServerSource.AUTOMATIC1111 -> "../models/hypernetworks" + ServerSource.SWARM_UI -> "" + else -> "" + } + } Text( modifier = Modifier .padding(top = 16.dp) .padding(horizontal = 16.dp) .align(Alignment.CenterHorizontally), text = stringResource( - id = when (type) { - //ToDo change empty state path depending on A1111/SWARM provider + when (type) { ExtraType.Lora -> R.string.extras_empty_sub_title_lora ExtraType.HyperNet -> R.string.extras_empty_sub_title_hypernet - } + }, + source.getName(), + path, ), textAlign = TextAlign.Center, ) diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/extras/ExtrasState.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/extras/ExtrasState.kt index e75460ac..5860bfe0 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/extras/ExtrasState.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/extras/ExtrasState.kt @@ -1,6 +1,7 @@ package com.shifthackz.aisdv1.presentation.modal.extras import androidx.compose.runtime.Immutable +import com.shifthackz.aisdv1.domain.entity.ServerSource import com.shifthackz.aisdv1.presentation.model.ErrorState import com.shifthackz.aisdv1.presentation.model.ExtraType import com.shifthackz.android.core.mvi.MviState @@ -8,6 +9,7 @@ import com.shifthackz.android.core.mvi.MviState @Immutable data class ExtrasState( val loading: Boolean = true, + val source: ServerSource = ServerSource.AUTOMATIC1111, val error: ErrorState = ErrorState.None, val prompt: String = "", val negativePrompt: String = "", diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/extras/ExtrasViewModel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/extras/ExtrasViewModel.kt index 53e53670..52f6e3a5 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/extras/ExtrasViewModel.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/extras/ExtrasViewModel.kt @@ -7,6 +7,7 @@ import com.shifthackz.aisdv1.core.common.time.TimeProvider import com.shifthackz.aisdv1.core.viewmodel.MviRxViewModel import com.shifthackz.aisdv1.domain.entity.LoRA import com.shifthackz.aisdv1.domain.entity.StableDiffusionHyperNetwork +import com.shifthackz.aisdv1.domain.preference.PreferenceManager import com.shifthackz.aisdv1.domain.usecase.sdhypernet.FetchAndGetHyperNetworksUseCase import com.shifthackz.aisdv1.domain.usecase.sdlora.FetchAndGetLorasUseCase import com.shifthackz.aisdv1.presentation.model.ErrorState @@ -18,11 +19,18 @@ class ExtrasViewModel( private val fetchAndGetLorasUseCase: FetchAndGetLorasUseCase, private val fetchAndGetHyperNetworksUseCase: FetchAndGetHyperNetworksUseCase, private val schedulersProvider: SchedulersProvider, + private val preferenceManager: PreferenceManager, private val timeProvider: TimeProvider, ) : MviRxViewModel() { override val initialState = ExtrasState() + init { + updateState { + it.copy(source = preferenceManager.source) + } + } + override fun processIntent(intent: ExtrasIntent) { when (intent) { ExtrasIntent.ApplyPrompts -> emitEffect( @@ -53,7 +61,15 @@ class ExtrasViewModel( ExtraType.Lora -> fetchAndGetLorasUseCase() ExtraType.HyperNet -> fetchAndGetHyperNetworksUseCase() } - .doOnSubscribe { updateState { it.copy(loading = true, type = type) } } + .doOnSubscribe { + updateState { state -> + state.copy( + loading = true, + type = type, + source = preferenceManager.source, + ) + } + } .subscribeOnMainThread(schedulersProvider) .subscribeBy( onError = { t -> @@ -64,6 +80,7 @@ class ExtrasViewModel( updateState { state -> state.copy( loading = false, + source = preferenceManager.source, error = ErrorState.None, prompt = prompt, negativePrompt = negativePrompt, diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/navigation/graph/MainNavGraph.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/navigation/graph/MainNavGraph.kt index 7ba8081e..5c0bf6d5 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/navigation/graph/MainNavGraph.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/navigation/graph/MainNavGraph.kt @@ -31,7 +31,7 @@ fun NavGraphBuilder.mainNavGraph() { ComposeNavigator.Destination(provider[ComposeNavigator::class]) { entry -> val sourceKey = entry.arguments ?.getInt(Constants.PARAM_SOURCE) - ?: ServerSetupLaunchSource.SPLASH.key + ?: ServerSetupLaunchSource.SPLASH.ordinal ServerSetupScreen(launchSourceKey = sourceKey) }.apply { route = Constants.ROUTE_SERVER_SETUP_FULL diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/navigation/router/main/MainRouterImpl.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/navigation/router/main/MainRouterImpl.kt index 23399f95..4e9978b7 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/navigation/router/main/MainRouterImpl.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/navigation/router/main/MainRouterImpl.kt @@ -34,7 +34,7 @@ internal class MainRouterImpl( } override fun navigateToServerSetup(source: ServerSetupLaunchSource) { - effectSubject.onNext(NavigationEffect.Navigate.RouteBuilder("${Constants.ROUTE_SERVER_SETUP}/${source.key}") { + effectSubject.onNext(NavigationEffect.Navigate.RouteBuilder("${Constants.ROUTE_SERVER_SETUP}/${source.ordinal}") { if (source == ServerSetupLaunchSource.SPLASH) { popUpTo(Constants.ROUTE_SPLASH) { inclusive = true diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupState.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupState.kt index 4fab4c96..555a4d32 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupState.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupState.kt @@ -85,13 +85,12 @@ data class ServerSetupState( ) } -//ToDo refactor key to enum ordinal -enum class ServerSetupLaunchSource(val key: Int) { - SPLASH(0), - SETTINGS(1); +enum class ServerSetupLaunchSource { + SPLASH, + SETTINGS; companion object { - fun fromKey(key: Int) = entries.firstOrNull { it.key == key } ?: SPLASH + fun fromKey(key: Int) = entries.firstOrNull { it.ordinal == key } ?: SPLASH } } diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/components/ConfigurationModeButton.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/components/ConfigurationModeButton.kt index 5643baec..cd7624a6 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/components/ConfigurationModeButton.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/components/ConfigurationModeButton.kt @@ -34,6 +34,7 @@ import com.shifthackz.aisdv1.domain.entity.ServerSource import com.shifthackz.aisdv1.presentation.R import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupState import com.shifthackz.aisdv1.presentation.screen.setup.mappers.mapToUi +import com.shifthackz.aisdv1.presentation.widget.source.getName @Composable fun ConfigurationModeButton( @@ -80,15 +81,7 @@ fun ConfigurationModeButton( modifier = Modifier .align(Alignment.CenterVertically) .padding(top = 8.dp, bottom = 8.dp), - text = stringResource(id = when (mode) { - ServerSource.AUTOMATIC1111 -> R.string.srv_type_own - ServerSource.HORDE -> R.string.srv_type_horde - ServerSource.LOCAL -> R.string.srv_type_local - ServerSource.HUGGING_FACE -> R.string.srv_type_hugging_face - ServerSource.OPEN_AI -> R.string.srv_type_open_ai - ServerSource.STABILITY_AI -> R.string.srv_type_stability_ai - ServerSource.SWARM_UI -> R.string.srv_type_swarm_ui - }), + text = mode.getName(), textAlign = TextAlign.Center, style = MaterialTheme.typography.bodyLarge, ) diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/source/ServerSourceLabel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/source/ServerSourceLabel.kt new file mode 100644 index 00000000..3638bb78 --- /dev/null +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/source/ServerSourceLabel.kt @@ -0,0 +1,19 @@ +package com.shifthackz.aisdv1.presentation.widget.source + +import androidx.compose.runtime.Composable +import androidx.compose.ui.res.stringResource +import com.shifthackz.aisdv1.domain.entity.ServerSource +import com.shifthackz.aisdv1.presentation.R + +@Composable +fun ServerSource.getName(): String { + return stringResource(id = when (this) { + ServerSource.AUTOMATIC1111 -> R.string.srv_type_own + ServerSource.HORDE -> R.string.srv_type_horde + ServerSource.LOCAL -> R.string.srv_type_local + ServerSource.HUGGING_FACE -> R.string.srv_type_hugging_face + ServerSource.OPEN_AI -> R.string.srv_type_open_ai + ServerSource.STABILITY_AI -> R.string.srv_type_stability_ai + ServerSource.SWARM_UI -> R.string.srv_type_swarm_ui + }) +} diff --git a/presentation/src/main/res/values-ru/strings.xml b/presentation/src/main/res/values-ru/strings.xml index d1b243c1..8908b816 100644 --- a/presentation/src/main/res/values-ru/strings.xml +++ b/presentation/src/main/res/values-ru/strings.xml @@ -96,10 +96,10 @@ Укажите свой URL-адрес Stable Diffusion WebUI Веб-интерфейс для Stable Diffusion, реализованный с использованием библиотеки Gradio. - Примеры URL-адресов сервера:\nhttp://192.168.0.2:7860\nhttp://yourdomain.com:7860\nhttps://yourdomain.com + Примеры URL-адресов сервера:\nhttp://192.168.0.2:%1$s\nhttp://yourdomain.com:%1$s\nhttps://yourdomain.com Этот режим позволяет проверить поведение программы, даже если у вас нет сервера Stable Diffusion WebUI.\n\nВ демонстрационном режиме программа игнорирует параметры генерации, не использует сервер искусственного интеллекта и возвращает фиктивные изображения. Перед подключением убедитесь, что:\n• вы используете AUTOMATIC1111 WebUI с аргументами --api --listen\n• ваш брандмауэр не блокирует порт 7860\n• телефон подключен к одной сети Wi-Fi с ПК - Перед подключением убедитесь, что:\n• ваш брандмауэр не блокирует порт 7860\n• телефон подключен к одной сети Wi-Fi с ПК + Перед подключением убедитесь, что:\n• ваш брандмауэр не блокирует порт 7801\n• телефон подключен к одной сети Wi-Fi с ПК Подключиться к Horde AI Horde AI – это краудсорсинговый распределенный кластер нод генерации изображений и текста. @@ -147,7 +147,7 @@ Выберите SD ML модель Очистить кэш приложения Дебаг Меню - Лора + ЛоРА Инверсия текста Инверсия Редактор тега @@ -265,9 +265,9 @@ Внести битый Base64 в БД Здесь пусто… - Добавьте что-нибуть на AUTOMATIC1111 сервере в: \n\n../models/Lora - Добавьте что-нибуть на AUTOMATIC1111 сервере в: \n\n../models/hypernetworks - Добавьте что-нибуть на AUTOMATIC1111 сервере в: \n\n../embeddings + Добавьте что-нибуть на %1$s сервере в: \n\n%2$s + Добавьте что-нибуть на %1$s сервере в: \n\n%2$s + Добавьте что-нибуть на %1$s сервере в: \n\n%2$s Зарисовка Регулировка diff --git a/presentation/src/main/res/values-tr/strings.xml b/presentation/src/main/res/values-tr/strings.xml index 429c7598..e8e48a1b 100644 --- a/presentation/src/main/res/values-tr/strings.xml +++ b/presentation/src/main/res/values-tr/strings.xml @@ -96,10 +96,10 @@ Lütfen Stable Diffusion WebUI(AUTOMATIC1111) URL adresinizi yazın. Gradio kütüphanesi kullanılarak uygulanan Stable Diffusion için bir web arayüzü. - Bazı sunucu örnekleri:\n• http://192.168.0.2:7860\n• http://alanadiniz.com:7860\n• https://alanadiniz.com + Bazı sunucu örnekleri:\n• http://192.168.0.2:%1$s\n• http://alanadiniz.com:%1$s\n• https://alanadiniz.com Bu mod, Stable Diffusion WebUI sunucunuz olmasa bile uygulama davranışını test etmenize olanak tanır.\n\nDemo modunda uygulama kullanıcı istemini görmezden gelir, AI sunucusunu kullanmaz ve bazı sahte görüntüler döndürür. Bağlanmadan önce şunlardan emin olun:\n• AUTOMATIC1111 WebUI\'yi --api --listen bayraklarıyla çalıştırıyorsunuz\n• Güvenlik duvarınız 7860 portunu engellemiyor\n• Telefon, bilgisayarınızla aynı WiFi\'da. - Bağlanmadan önce şunlardan emin olun:\n• Güvenlik duvarınız 7860 portunu engellemiyor\n• Telefon, bilgisayarınızla aynı WiFi\'da. + Bağlanmadan önce şunlardan emin olun:\n• Güvenlik duvarınız 7801 portunu engellemiyor\n• Telefon, bilgisayarınızla aynı WiFi\'da. Horde AI bulutuna bağlanın Horde AI, Görüntü oluşturma çalışanları ve metin oluşturma çalışanlarından oluşan kitle kaynaklı dağıtılmış bir kümedir. @@ -147,7 +147,7 @@ SD Modeli seçin Uygulama önbelleğini temizle Hata Ayıklama Menüsü - Lora + LoRA Metin İnversiyon İnversiyon Etiketi düzenle @@ -265,9 +265,9 @@ Kötü Base64\'ü DB\'ye yerleştirin Burada hiçbir şey… - AUTOMATIC1111 sunucusuna biraz içerik ekleyin: \n\n../models/Lora - AUTOMATIC1111 sunucusuna biraz içerik ekleyin: \n\n../models/hypernetworks - AUTOMATIC1111 sunucusuna biraz içerik ekleyin: \n\n../embeddings + %1$s sunucusuna biraz içerik ekleyin: \n\n%2$s + %1$s sunucusuna biraz içerik ekleyin: \n\n%2$s + %1$s sunucusuna biraz içerik ekleyin: \n\n%2$s Çizmek Ayarlamak diff --git a/presentation/src/main/res/values-uk/strings.xml b/presentation/src/main/res/values-uk/strings.xml index ccb95b7b..50cddb64 100644 --- a/presentation/src/main/res/values-uk/strings.xml +++ b/presentation/src/main/res/values-uk/strings.xml @@ -96,10 +96,10 @@ Вкажіть свою URL-адресу Stable Diffusion WebUI Веб-інтерфейс для Stable Diffusion, реалізований за допомогою бібліотеки Gradio. - Ось приклади URL-адрес сервера:\nhttp://192.168.0.2:7860\nhttp://yourdomain.com:7860\nhttps://yourdomain.com + Ось приклади URL-адрес сервера:\nhttp://192.168.0.2:%1$s\nhttp://yourdomain.com:%1$s\nhttps://yourdomain.com Цей режим дозволяє перевірити поведінку програми, навіть якщо у вас немає сервера Stable Diffusion WebUI.\n\nУ демонстраційному режимі програма ігнорує параметри генерації, не використовує сервер штучного інтелекту та повертає фіктивні зображення. Перед підключенням переконайтеся, що:\n• ви використовуєте AUTOMATIC1111 WebUI з аргументами --api --listen\n• ваш брандмауер не блокує порт 7860\n• телефон підключено до однієї мережі Wi-Fi з вашим ПК - Перед підключенням переконайтеся, що:\n• ваш брандмауер не блокує порт 7860\n• телефон підключено до однієї мережі Wi-Fi з вашим ПК + Перед підключенням переконайтеся, що:\n• ваш брандмауер не блокує порт 7801\n• телефон підключено до однієї мережі Wi-Fi з вашим ПК Підключитися до Horde AI Horde AI — це краудсорсинговий розподілений кластер нод генерації зображень і тексту. @@ -147,7 +147,7 @@ Оберіть SD ML модель Очистити кеш додатку Дебаг Меню - Лора + ЛоРА Інверсія тексту Інверсія Редактор тегу @@ -265,9 +265,9 @@ Внести битий Base64 в БД Тут порожньо… - Додайте щось на AUTOMATIC1111 сервері до: \n\n../models/Lora - Додайте щось на AUTOMATIC1111 сервері до: \n\n../models/hypernetworks - Додайте щось на AUTOMATIC1111 сервері до: \n\n../embeddings + Додайте щось на %1$s сервері до: \n\n%2$s + Додайте щось на %1$s сервері до: \n\n%2$s + Додайте щось на %1$s сервері до: \n\n%2$s Креслення Регулювання diff --git a/presentation/src/main/res/values-zh/strings.xml b/presentation/src/main/res/values-zh/strings.xml index 5b127047..76e03036 100644 --- a/presentation/src/main/res/values-zh/strings.xml +++ b/presentation/src/main/res/values-zh/strings.xml @@ -123,10 +123,10 @@ 提供您的Stable Diffusion WebUI URL 使用 Gradio 库实现的稳定扩散的 Web 界面。 - 以下是服务器URL的示例:\n• http://192.168.0.2:7860\n• http://yourdomain.com:7860\n• https://yourdomain.com + 以下是服务器URL的示例:\n• http://192.168.0.2:%1$s\n• http://yourdomain.com:%1$s\n• https://yourdomain.com 此模式允许您测试应用程序的行为,即使您没有Stable Diffusion WebUI服务器。\n\n在演示模式下,应用程序忽略用户提示,不使用AI服务器,并返回一些模拟图像。 在连接之前确保:\n• 您正在运行AUTOMATIC1111 WebUI并带有标志 --api --listen\n• 您的防火墙没有阻止7860端口\n• 手机与您的PC在同一WiFi下 - 在连接之前确保:\n• 您的防火墙没有阻止7860端口\n• 手机与您的PC在同一WiFi下 + 在连接之前确保:\n• 您的防火墙没有阻止7801端口\n• 手机与您的PC在同一WiFi下 连接到Horde AI云 @@ -184,7 +184,7 @@ 选择SD ML模型 清除应用缓存 调试菜单 - Lora + LoRA 超网络 H-Net 文本反转 @@ -327,9 +327,9 @@ 这里什么都没有… - 在AUTOMATIC1111服务器上添加一些内容到:\n\n../models/Lora - 在AUTOMATIC1111服务器上添加一些内容到:\n\n../models/hypernetworks - 在AUTOMATIC1111服务器上添加一些内容到:\n\n../embeddings + 在%1$s服务器上添加一些内容到:\n\n%2$s + 在%1$s服务器上添加一些内容到:\n\n%2$s + 在%1$s服务器上添加一些内容到:\n\n%2$s 修复 diff --git a/presentation/src/main/res/values/strings.xml b/presentation/src/main/res/values/strings.xml index 3c4d3fe4..fb8b6821 100755 --- a/presentation/src/main/res/values/strings.xml +++ b/presentation/src/main/res/values/strings.xml @@ -165,7 +165,7 @@ Select SD ML Model Clear app cache Debug Menu - Lora + LoRA Hypernetworks H-Net Textual Inversion @@ -287,9 +287,9 @@ Insert bad Base64 in DB Nothing here… - Add some content on AUTOMATIC1111 server to: \n\n../models/Lora - Add some content on AUTOMATIC1111 server to: \n\n../models/hypernetworks - Add some content on AUTOMATIC1111 server to: \n\n../embeddings + Add some content on %1$s server to: \n\n%2$s + Add some content on %1$s server to: \n\n%2$s + Add some content on %1$s server to: \n\n.%2$s Inpaint Draw diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/modal/embedding/EmbeddingViewModelTest.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/modal/embedding/EmbeddingViewModelTest.kt index 78d27d5b..602363bd 100644 --- a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/modal/embedding/EmbeddingViewModelTest.kt +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/modal/embedding/EmbeddingViewModelTest.kt @@ -1,5 +1,7 @@ package com.shifthackz.aisdv1.presentation.modal.embedding +import com.shifthackz.aisdv1.domain.entity.ServerSource +import com.shifthackz.aisdv1.domain.preference.PreferenceManager import com.shifthackz.aisdv1.domain.usecase.sdembedding.FetchAndGetEmbeddingsUseCase import com.shifthackz.aisdv1.presentation.core.CoreViewModelTest import com.shifthackz.aisdv1.presentation.mocks.mockEmbeddings @@ -12,18 +14,30 @@ import io.reactivex.rxjava3.core.Single import kotlinx.coroutines.flow.firstOrNull import kotlinx.coroutines.test.runTest import org.junit.Assert +import org.junit.Before import org.junit.Test class EmbeddingViewModelTest : CoreViewModelTest() { private val stubException = Throwable("Something went wrong.") private val stubFetchAndGetEmbeddingsUseCase = mockk() + private val stubPreferenceManager = mockk() override fun initializeViewModel() = EmbeddingViewModel( fetchAndGetEmbeddingsUseCase = stubFetchAndGetEmbeddingsUseCase, + preferenceManager = stubPreferenceManager, schedulersProvider = stubSchedulersProvider, ) + @Before + override fun initialize() { + super.initialize() + + every { + stubPreferenceManager.source + } returns ServerSource.AUTOMATIC1111 + } + @Test fun `given update data, fetch embeddings successful, expected UI state with embeddings list`() { every { diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/modal/extras/ExtrasViewModelTest.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/modal/extras/ExtrasViewModelTest.kt index 93b872c0..8eed3235 100644 --- a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/modal/extras/ExtrasViewModelTest.kt +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/modal/extras/ExtrasViewModelTest.kt @@ -1,6 +1,8 @@ package com.shifthackz.aisdv1.presentation.modal.extras import com.shifthackz.aisdv1.core.common.time.TimeProvider +import com.shifthackz.aisdv1.domain.entity.ServerSource +import com.shifthackz.aisdv1.domain.preference.PreferenceManager import com.shifthackz.aisdv1.domain.usecase.sdhypernet.FetchAndGetHyperNetworksUseCase import com.shifthackz.aisdv1.domain.usecase.sdlora.FetchAndGetLorasUseCase import com.shifthackz.aisdv1.presentation.core.CoreViewModelTest @@ -22,12 +24,14 @@ class ExtrasViewModelTest : CoreViewModelTest() { private val stubException = Throwable("Something went wrong.") private val stubFetchAndGetLorasUseCase = mockk() private val stubFetchAndGetHyperNetworksUseCase = mockk() + private val stubPreferenceManager = mockk() private val stubTimeProvider = mockk() override fun initializeViewModel() = ExtrasViewModel( stubFetchAndGetLorasUseCase, stubFetchAndGetHyperNetworksUseCase, stubSchedulersProvider, + stubPreferenceManager, stubTimeProvider, ) @@ -35,6 +39,10 @@ class ExtrasViewModelTest : CoreViewModelTest() { override fun initialize() { super.initialize() + every { + stubPreferenceManager.source + } returns ServerSource.AUTOMATIC1111 + every { stubTimeProvider.nanoTime() } returns MOCK_SYS_TIME diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/navigation/router/main/MainRouterImplTest.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/navigation/router/main/MainRouterImplTest.kt index 891ff322..64527033 100644 --- a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/navigation/router/main/MainRouterImplTest.kt +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/navigation/router/main/MainRouterImplTest.kt @@ -58,7 +58,7 @@ class MainRouterImplTest { .assertNoErrors() .assertValueAt(0) { actual -> val expectedRoute = - "${Constants.ROUTE_SERVER_SETUP}/${ServerSetupLaunchSource.SPLASH.key}" + "${Constants.ROUTE_SERVER_SETUP}/${ServerSetupLaunchSource.SPLASH.ordinal}" actual is NavigationEffect.Navigate.RouteBuilder && actual.route == expectedRoute } @@ -73,7 +73,7 @@ class MainRouterImplTest { .assertNoErrors() .assertValueAt(0) { actual -> val expectedRoute = - "${Constants.ROUTE_SERVER_SETUP}/${ServerSetupLaunchSource.SETTINGS.key}" + "${Constants.ROUTE_SERVER_SETUP}/${ServerSetupLaunchSource.SETTINGS.ordinal}" actual is NavigationEffect.Navigate.RouteBuilder && actual.route == expectedRoute } From 7663735e439e8d78fa4ec343f5f21d47e0412785 Mon Sep 17 00:00:00 2001 From: ShiftHackZ Date: Mon, 5 Aug 2024 14:38:13 +0300 Subject: [PATCH 10/11] Fix Gallery detail original img2img base64 output format --- .../aisdv1/data/remote/SwarmUiGenerationRemoteDataSource.kt | 1 + .../aisdv1/presentation/navigation/graph/MainNavGraph.kt | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/remote/SwarmUiGenerationRemoteDataSource.kt b/data/src/main/java/com/shifthackz/aisdv1/data/remote/SwarmUiGenerationRemoteDataSource.kt index 913561b1..958be2e3 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/remote/SwarmUiGenerationRemoteDataSource.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/remote/SwarmUiGenerationRemoteDataSource.kt @@ -50,6 +50,7 @@ class SwarmUiGenerationRemoteDataSource( request = encodedPayload.mapToSwarmUiRequest(sessionId, model), ) } + .map { (_, outBase64) -> payload to outBase64 } .map(Pair::mapCloudToAiGenResult) private fun generate( diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/navigation/graph/MainNavGraph.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/navigation/graph/MainNavGraph.kt index 5c0bf6d5..37c57c91 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/navigation/graph/MainNavGraph.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/navigation/graph/MainNavGraph.kt @@ -99,4 +99,4 @@ private fun debugMenuTab() = NavItem( icon = NavItem.Icon.Vector( vector = Icons.Default.Deck, ), -) \ No newline at end of file +) From c46a566af6f1492d5f383d9de5dad55068e96db5 Mon Sep 17 00:00:00 2001 From: ShiftHackZ Date: Mon, 5 Aug 2024 14:44:44 +0300 Subject: [PATCH 11/11] Finalize implementation --- .../aisdv1/data/di/RemoteDataSourceModule.kt | 2 +- .../data/mappers/ImageToImagePayloadMappers.kt | 1 - .../aisdv1/data/preference/PreferenceManagerImpl.kt | 4 ++-- .../aisdv1/data/preference/SessionPreferenceImpl.kt | 7 ------- .../aisdv1/data/mocks/SwarmUiModelEntityMocks.kt | 10 +++++----- .../aisdv1/data/mocks/SwarmUiModelMocks.kt | 2 +- .../data/preference/PreferenceManagerImplTest.kt | 6 +++--- .../data/preference/SessionPreferenceImplTest.kt | 12 ++++++------ .../aisdv1/domain/preference/PreferenceManager.kt | 2 +- .../aisdv1/domain/preference/SessionPreference.kt | 1 - .../usecase/settings/GetConfigurationUseCaseImpl.kt | 2 +- .../settings/SetServerConfigurationUseCaseImpl.kt | 2 +- .../usecase/splash/SplashNavigationUseCaseImpl.kt | 2 +- .../settings/GetConfigurationUseCaseImplTest.kt | 2 +- .../SetServerConfigurationUseCaseImplTest.kt | 2 +- .../splash/SplashNavigationUseCaseImplTest.kt | 8 ++++---- 16 files changed, 28 insertions(+), 37 deletions(-) diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/di/RemoteDataSourceModule.kt b/data/src/main/java/com/shifthackz/aisdv1/data/di/RemoteDataSourceModule.kt index 40786808..a420ba6d 100755 --- a/data/src/main/java/com/shifthackz/aisdv1/data/di/RemoteDataSourceModule.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/di/RemoteDataSourceModule.kt @@ -62,7 +62,7 @@ val remoteDataSourceModule = module { val chain = if (prefs.source == ServerSource.SWARM_UI) { Single.fromCallable(prefs::swarmUiServerUrl) } else { - Single.fromCallable(prefs::automatic1111serverUrl) + Single.fromCallable(prefs::automatic1111ServerUrl) } chain .map(String::fixUrlSlashes) diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/ImageToImagePayloadMappers.kt b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/ImageToImagePayloadMappers.kt index f764da02..467ca8a3 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/ImageToImagePayloadMappers.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/ImageToImagePayloadMappers.kt @@ -103,7 +103,6 @@ fun ImageToImagePayload.mapToSwarmUiRequest( SwarmUiGenerationRequest( sessionId = sessionId, model = swarmUiModel, -// initImage = "data:image/png;base64,${base64DefaultToNoWrap(base64Image)}", initImage = base64Image, initImageCreativity = denoisingStrength.roundTo(2).toString(), images = 1, diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImpl.kt index ad485239..ae130e47 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImpl.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImpl.kt @@ -20,7 +20,7 @@ class PreferenceManagerImpl( private val preferencesChangedSubject: BehaviorSubject = BehaviorSubject.createDefault(Unit) - override var automatic1111serverUrl: String + override var automatic1111ServerUrl: String get() = (preferences.getString(KEY_SERVER_URL, "") ?: "").fixUrlSlashes() set(value) = preferences.edit() .putString(KEY_SERVER_URL, value.fixUrlSlashes()) @@ -207,7 +207,7 @@ class PreferenceManagerImpl( .toFlowable(BackpressureStrategy.LATEST) .map { Settings( - serverUrl = automatic1111serverUrl, + serverUrl = automatic1111ServerUrl, sdModel = sdModel, demoMode = demoMode, monitorConnectivity = monitorConnectivity, diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/preference/SessionPreferenceImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/preference/SessionPreferenceImpl.kt index 162016fc..00be950d 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/preference/SessionPreferenceImpl.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/preference/SessionPreferenceImpl.kt @@ -4,15 +4,8 @@ import com.shifthackz.aisdv1.domain.preference.SessionPreference class SessionPreferenceImpl : SessionPreference { - private var _coinsPerDay: Int = -1 private var _swarmUiSessionId: String = "" - override var coinsPerDay: Int - get() = _coinsPerDay - set(value) { - _coinsPerDay = value - } - override var swarmUiSessionId: String get() = _swarmUiSessionId set(value) { diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/SwarmUiModelEntityMocks.kt b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/SwarmUiModelEntityMocks.kt index 128e54f3..1f151317 100644 --- a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/SwarmUiModelEntityMocks.kt +++ b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/SwarmUiModelEntityMocks.kt @@ -4,9 +4,9 @@ import com.shifthackz.aisdv1.storage.db.cache.entity.SwarmUiModelEntity val mockSwarmUiModelEntities = listOf( SwarmUiModelEntity( - "5598", - "5598", - "5598", - "", - ) + id = "5598", + name = "5598", + title = "5598", + author = "", + ), ) diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/SwarmUiModelMocks.kt b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/SwarmUiModelMocks.kt index 22d794cd..2917e3bd 100644 --- a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/SwarmUiModelMocks.kt +++ b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/SwarmUiModelMocks.kt @@ -7,5 +7,5 @@ val mockSwarmUiModels = listOf( name = "5598", title = "5598", author = "5598", - ) + ), ) diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImplTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImplTest.kt index 6a0948a8..0535d7f8 100644 --- a/data/src/test/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImplTest.kt +++ b/data/src/test/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImplTest.kt @@ -71,14 +71,14 @@ class PreferenceManagerImplTest { whenever(stubPreference.getString(eq(KEY_SERVER_URL), any())) .thenReturn("") - Assert.assertEquals("", preferenceManager.automatic1111serverUrl) + Assert.assertEquals("", preferenceManager.automatic1111ServerUrl) whenever(stubPreference.getString(eq(KEY_SERVER_URL), any())) .thenReturn("https://192.168.0.1:7860") - preferenceManager.automatic1111serverUrl = "https://192.168.0.1:7860" + preferenceManager.automatic1111ServerUrl = "https://192.168.0.1:7860" - Assert.assertEquals("https://192.168.0.1:7860", preferenceManager.automatic1111serverUrl) + Assert.assertEquals("https://192.168.0.1:7860", preferenceManager.automatic1111ServerUrl) preferenceManager .observe() diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/preference/SessionPreferenceImplTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/preference/SessionPreferenceImplTest.kt index 7d6c9df4..2e645a64 100644 --- a/data/src/test/java/com/shifthackz/aisdv1/data/preference/SessionPreferenceImplTest.kt +++ b/data/src/test/java/com/shifthackz/aisdv1/data/preference/SessionPreferenceImplTest.kt @@ -8,14 +8,14 @@ class SessionPreferenceImplTest { private val sessionPreference = SessionPreferenceImpl() @Test - fun `given user reads default coinsPerDay value, expected -1`() { - Assert.assertEquals(-1, sessionPreference.coinsPerDay) + fun `given user reads default swarmUiSessionId value, expected empty String`() { + Assert.assertEquals("", sessionPreference.swarmUiSessionId) } @Test - fun `given user reads default coinsPerDay value, then changes it, expected -1, then changed value`() { - Assert.assertEquals(-1, sessionPreference.coinsPerDay) - sessionPreference.coinsPerDay = 5598 - Assert.assertEquals(5598, sessionPreference.coinsPerDay) + fun `given user reads default coinsPerDay value, then changes it, expected empty String, then changed value`() { + Assert.assertEquals("", sessionPreference.swarmUiSessionId) + sessionPreference.swarmUiSessionId = "5598" + Assert.assertEquals("5598", sessionPreference.swarmUiSessionId) } } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/preference/PreferenceManager.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/preference/PreferenceManager.kt index c0f79422..499abb42 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/preference/PreferenceManager.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/preference/PreferenceManager.kt @@ -5,7 +5,7 @@ import com.shifthackz.aisdv1.domain.entity.Settings import io.reactivex.rxjava3.core.Flowable interface PreferenceManager { - var automatic1111serverUrl: String + var automatic1111ServerUrl: String var swarmUiServerUrl: String var swarmUiModel: String var demoMode: Boolean diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/preference/SessionPreference.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/preference/SessionPreference.kt index ab918e77..79aadadb 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/preference/SessionPreference.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/preference/SessionPreference.kt @@ -1,6 +1,5 @@ package com.shifthackz.aisdv1.domain.preference interface SessionPreference { - var coinsPerDay: Int var swarmUiSessionId: String } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImpl.kt index 438ce9b4..a46f4bd7 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImpl.kt @@ -12,7 +12,7 @@ internal class GetConfigurationUseCaseImpl( override fun invoke(): Single = Single.just( Configuration( - serverUrl = preferenceManager.automatic1111serverUrl, + serverUrl = preferenceManager.automatic1111ServerUrl, swarmUiUrl = preferenceManager.swarmUiServerUrl, swarmUiModel = preferenceManager.swarmUiModel, demoMode = preferenceManager.demoMode, diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImpl.kt index fab9973c..83ceab34 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImpl.kt @@ -14,7 +14,7 @@ internal class SetServerConfigurationUseCaseImpl( Completable.fromAction { authorizationStore.storeAuthorizationCredentials(configuration.authCredentials) preferenceManager.source = configuration.source - preferenceManager.automatic1111serverUrl = configuration.serverUrl + preferenceManager.automatic1111ServerUrl = configuration.serverUrl preferenceManager.swarmUiServerUrl = configuration.swarmUiUrl preferenceManager.swarmUiModel = configuration.swarmUiModel preferenceManager.demoMode = configuration.demoMode diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/splash/SplashNavigationUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/splash/SplashNavigationUseCaseImpl.kt index b550470d..2dbb2831 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/splash/SplashNavigationUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/splash/SplashNavigationUseCaseImpl.kt @@ -15,7 +15,7 @@ internal class SplashNavigationUseCaseImpl( Action.LAUNCH_SERVER_SETUP } - preferenceManager.automatic1111serverUrl.isEmpty() + preferenceManager.automatic1111ServerUrl.isEmpty() && preferenceManager.source == ServerSource.AUTOMATIC1111 -> { Action.LAUNCH_SERVER_SETUP } diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImplTest.kt index 7007cdb9..9ce3e0af 100644 --- a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImplTest.kt +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImplTest.kt @@ -25,7 +25,7 @@ class GetConfigurationUseCaseImplTest { } returns AuthorizationCredentials.None every { - stubPreferenceManager::automatic1111serverUrl.get() + stubPreferenceManager::automatic1111ServerUrl.get() } returns mockConfiguration.serverUrl every { diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImplTest.kt index ef1bc037..e41c7ab7 100644 --- a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImplTest.kt +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImplTest.kt @@ -28,7 +28,7 @@ class SetServerConfigurationUseCaseImplTest { } returns Unit every { - stubPreferenceManager::automatic1111serverUrl.set(any()) + stubPreferenceManager::automatic1111ServerUrl.set(any()) } returns Unit every { diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/splash/SplashNavigationUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/splash/SplashNavigationUseCaseImplTest.kt index 7419e773..cb76cb84 100644 --- a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/splash/SplashNavigationUseCaseImplTest.kt +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/splash/SplashNavigationUseCaseImplTest.kt @@ -28,7 +28,7 @@ class SplashNavigationUseCaseImplTest { whenever(stubPreferenceManager.forceSetupAfterUpdate) .thenReturn(false) - whenever(stubPreferenceManager.automatic1111serverUrl) + whenever(stubPreferenceManager.automatic1111ServerUrl) .thenReturn("") whenever(stubPreferenceManager.source) @@ -45,7 +45,7 @@ class SplashNavigationUseCaseImplTest { whenever(stubPreferenceManager.forceSetupAfterUpdate) .thenReturn(false) - whenever(stubPreferenceManager.automatic1111serverUrl) + whenever(stubPreferenceManager.automatic1111ServerUrl) .thenReturn("http://192.168.0.1:7860") whenever(stubPreferenceManager.source) @@ -62,7 +62,7 @@ class SplashNavigationUseCaseImplTest { whenever(stubPreferenceManager.forceSetupAfterUpdate) .thenReturn(false) - whenever(stubPreferenceManager.automatic1111serverUrl) + whenever(stubPreferenceManager.automatic1111ServerUrl) .thenReturn("") whenever(stubPreferenceManager.source) @@ -79,7 +79,7 @@ class SplashNavigationUseCaseImplTest { whenever(stubPreferenceManager.forceSetupAfterUpdate) .thenReturn(false) - whenever(stubPreferenceManager.automatic1111serverUrl) + whenever(stubPreferenceManager.automatic1111ServerUrl) .thenReturn("http://192.168.0.1:7860") whenever(stubPreferenceManager.source)