diff --git a/README.md b/README.md index 8c6bce68..5350358a 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,7 @@ Stable Diffusion AI is an easy-to-use app that lets you quickly generate images - Can use server environment powered by [AI Horde](https://stablehorde.net/) (a crowdsourced distributed cluster of Stable Diffusion workers) - Can use server environment powered by [Stable-Diffusion-WebUI](https://github.com/AUTOMATIC1111/stable-diffusion-webui) (AUTOMATIC1111) +- Can use server environment powered by [SwarmUI](https://github.com/mcmonkeyprojects/SwarmUI) - Can use server envitonment powered by [Hugging Face Inference API](https://huggingface.co/docs/api-inference/quicktour). - Can use server environment powered by [OpenAI](https://platform.openai.com/docs/api-reference/images) (DALL-E-2, DALL-E-3). - Can use server environment powered by [Stability AI](https://platform.stability.ai/). @@ -70,31 +71,39 @@ You can have it running either on your own hardware with modern GPU from Nvidia If for some reason you have no ability to run your server instance, you can toggle the **Demo mode** switch on server setup page: it will allow you to test the app and get familiar with it, but it will return some mock images instead of AI-generated ones. -### Option 2: Use AI Horde +### Option 2: Use your own SwarmUI instance + +This requires you to have the SwarmUI that is running in server mode. + +You can have it running either on your own hardware with modern GPU from Nvidia or AMD, or running it using Google Colab. + +Please refer to the [SwarmUI documentation](https://github.com/mcmonkeyprojects/SwarmUI?tab=readme-ov-file#swarmui) for installation instructions. + +### Option 3: Use AI Horde [AI Horde](https://stablehorde.net/) is a crowdsourced distributed cluster of Image generation workers and text generation workers. AI Horde requires to use API KEY, this mobile app alows to use either default API KEY (which is "0000000000"), or type your own. You can sign up and get your own AI Horde API KEY [here](https://stablehorde.net/register). -### Option 3: Hugging Face Inference +### Option 4: Hugging Face Inference [Hugging Face Inference API](https://huggingface.co/docs/api-inference/index) allows to test and evaluate, over 150,000 publicly accessible machine learning models, or your own private models, via simple HTTP requests, with fast inference hosted on Hugging Face shared infrastructure. This service is free, but is rate-limited. Hugging Face Inference requires to use API KEY, which can be created in [Hugging Face account settings](https://huggingface.co/settings/tokens). -### Option 4: OpenAI +### Option 5: OpenAI OpenAI provides a service for text to image generation using [DALLE-2](https://openai.com/dall-e-2) or [DALLE-3](https://openai.com/dall-e-3) models. This service is paid. OpenAI requires to use API KEY, which can be created in [OpenAI API Key settings](https://platform.openai.com/api-keys). -### Option 5: StabilityAI +### Option 6: StabilityAI [StabilityAI](https://platform.stability.ai/) is the image generation service provided by DreamStudio. StabilityAI requires to use API KEY, which can be created in [API Keys page](https://platform.stability.ai/account/keys). -### Option 6: Local Diffusion (Beta) +### Option 7: Local Diffusion (Beta) Only **txt2img** mode is supported. diff --git a/app/build.gradle b/app/build.gradle index bc3f0a73..8e0b2216 100755 --- a/app/build.gradle +++ b/app/build.gradle @@ -34,6 +34,7 @@ android { buildConfigField "String", "DONATE_URL", "\"https://www.buymeacoffee.com/shifthackz\"" buildConfigField "String", "GITHUB_SOURCE_URL", "\"https://github.com/ShiftHackZ/Stable-Diffusion-Android\"" buildConfigField "String", "SETUP_INSTRUCTIONS_URL", "\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki\"" + buildConfigField "String", "SWARM_UI_INFO_URL", "\"https://github.com/mcmonkeyprojects/SwarmUI/tree/master/docs\"" resourceConfigurations = ["en", "ru", "uk", "tr", "zh"] } diff --git a/app/src/main/java/com/shifthackz/aisdv1/app/di/ProvidersModule.kt b/app/src/main/java/com/shifthackz/aisdv1/app/di/ProvidersModule.kt index 933011dc..6e43a084 100755 --- a/app/src/main/java/com/shifthackz/aisdv1/app/di/ProvidersModule.kt +++ b/app/src/main/java/com/shifthackz/aisdv1/app/di/ProvidersModule.kt @@ -107,6 +107,7 @@ val providersModule = module { override val donateUrl: String = BuildConfig.DONATE_URL override val gitHubSourceUrl: String = BuildConfig.GITHUB_SOURCE_URL override val setupInstructionsUrl: String = BuildConfig.SETUP_INSTRUCTIONS_URL + override val swarmUiInfoUrl: String = BuildConfig.SWARM_UI_INFO_URL override val demoModeUrl: String = BuildConfig.DEMO_MODE_API_URL } } diff --git a/core/common/src/main/java/com/shifthackz/aisdv1/core/common/links/LinksProvider.kt b/core/common/src/main/java/com/shifthackz/aisdv1/core/common/links/LinksProvider.kt index bcf5aed0..ac6670be 100644 --- a/core/common/src/main/java/com/shifthackz/aisdv1/core/common/links/LinksProvider.kt +++ b/core/common/src/main/java/com/shifthackz/aisdv1/core/common/links/LinksProvider.kt @@ -10,5 +10,6 @@ interface LinksProvider { val donateUrl: String val gitHubSourceUrl: String val setupInstructionsUrl: String + val swarmUiInfoUrl: String val demoModeUrl: String } diff --git a/core/common/src/main/java/com/shifthackz/aisdv1/core/common/model/Hexagonal.kt b/core/common/src/main/java/com/shifthackz/aisdv1/core/common/model/Hexagonal.kt new file mode 100644 index 00000000..4bc94dde --- /dev/null +++ b/core/common/src/main/java/com/shifthackz/aisdv1/core/common/model/Hexagonal.kt @@ -0,0 +1,15 @@ +package com.shifthackz.aisdv1.core.common.model + +import java.io.Serializable + +data class Hexagonal( + val first: A, + val second: B, + val third: C, + val fourth: D, + val fifth: E, + val sixth: F, +) : Serializable { + + override fun toString(): String = "($first, $second, $third, $fourth, $fifth, $sixth)" +} diff --git a/core/common/src/main/java/com/shifthackz/aisdv1/core/common/model/Quadruple.kt b/core/common/src/main/java/com/shifthackz/aisdv1/core/common/model/Quadruple.kt index 38499203..7234f7fa 100644 --- a/core/common/src/main/java/com/shifthackz/aisdv1/core/common/model/Quadruple.kt +++ b/core/common/src/main/java/com/shifthackz/aisdv1/core/common/model/Quadruple.kt @@ -2,7 +2,6 @@ package com.shifthackz.aisdv1.core.common.model import java.io.Serializable - data class Quadruple( val first: A, val second: B, diff --git a/core/imageprocessing/src/main/java/com/shifthackz/aisdv1/core/imageprocessing/Base64EncodingConverter.kt b/core/imageprocessing/src/main/java/com/shifthackz/aisdv1/core/imageprocessing/Base64EncodingConverter.kt new file mode 100644 index 00000000..27542df2 --- /dev/null +++ b/core/imageprocessing/src/main/java/com/shifthackz/aisdv1/core/imageprocessing/Base64EncodingConverter.kt @@ -0,0 +1,36 @@ +package com.shifthackz.aisdv1.core.imageprocessing + +import com.shifthackz.aisdv1.core.common.log.errorLog +import com.shifthackz.aisdv1.core.imageprocessing.Base64EncodingConverter.Input +import com.shifthackz.aisdv1.core.imageprocessing.Base64EncodingConverter.Output +import com.shifthackz.aisdv1.core.imageprocessing.contract.RxImageProcessor +import com.shifthackz.aisdv1.core.imageprocessing.utils.base64DefaultToNoWrap +import io.reactivex.rxjava3.core.Scheduler +import io.reactivex.rxjava3.core.Single + +private typealias Base64EncodingProcessor = RxImageProcessor + +class Base64EncodingConverter( + private val processingScheduler: Scheduler, +) : Base64EncodingProcessor { + + override fun invoke(input: Input): Single = Single + .create { emitter -> + convert(input).fold( + onSuccess = emitter::onSuccess, + onFailure = emitter::onError, + ) + } + .onErrorReturn { t -> + errorLog(t) + Output(input.base64) + } + .subscribeOn(processingScheduler) + + private fun convert(input: Input): Result = runCatching { + Output(base64DefaultToNoWrap(input.base64)) + } + + data class Input(val base64: String) + data class Output(val base64: String) +} diff --git a/core/imageprocessing/src/main/java/com/shifthackz/aisdv1/core/imageprocessing/di/ImageProcessingModule.kt b/core/imageprocessing/src/main/java/com/shifthackz/aisdv1/core/imageprocessing/di/ImageProcessingModule.kt index b7436b0a..5608ffab 100644 --- a/core/imageprocessing/src/main/java/com/shifthackz/aisdv1/core/imageprocessing/di/ImageProcessingModule.kt +++ b/core/imageprocessing/src/main/java/com/shifthackz/aisdv1/core/imageprocessing/di/ImageProcessingModule.kt @@ -2,6 +2,7 @@ package com.shifthackz.aisdv1.core.imageprocessing.di import android.graphics.BitmapFactory import com.shifthackz.aisdv1.core.common.schedulers.SchedulersProvider +import com.shifthackz.aisdv1.core.imageprocessing.Base64EncodingConverter import com.shifthackz.aisdv1.core.imageprocessing.Base64ToBitmapConverter import com.shifthackz.aisdv1.core.imageprocessing.BitmapToBase64Converter import com.shifthackz.aisdv1.core.imageprocessing.R @@ -20,4 +21,8 @@ val imageProcessingModule = module { factory { BitmapToBase64Converter(get().computation) } + + factory { + Base64EncodingConverter(get().computation) + } } diff --git a/core/imageprocessing/src/main/java/com/shifthackz/aisdv1/core/imageprocessing/utils/Base64ImageUtils.kt b/core/imageprocessing/src/main/java/com/shifthackz/aisdv1/core/imageprocessing/utils/Base64ImageUtils.kt index 5865d72d..85557818 100644 --- a/core/imageprocessing/src/main/java/com/shifthackz/aisdv1/core/imageprocessing/utils/Base64ImageUtils.kt +++ b/core/imageprocessing/src/main/java/com/shifthackz/aisdv1/core/imageprocessing/utils/Base64ImageUtils.kt @@ -15,3 +15,8 @@ fun bitmapToBase64(bitmap: Bitmap): String { bitmap.compress(Bitmap.CompressFormat.PNG, 100, outputStream) return Base64.encodeToString(outputStream.toByteArray(), Base64.DEFAULT) } + +fun base64DefaultToNoWrap(base64Default: String): String { + val byteArray = Base64.decode(base64Default, Base64.DEFAULT) + return Base64.encodeToString(byteArray, Base64.NO_WRAP) +} diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/di/LocalDataSourceModule.kt b/data/src/main/java/com/shifthackz/aisdv1/data/di/LocalDataSourceModule.kt index 97410914..460c294c 100755 --- a/data/src/main/java/com/shifthackz/aisdv1/data/di/LocalDataSourceModule.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/di/LocalDataSourceModule.kt @@ -3,25 +3,27 @@ package com.shifthackz.aisdv1.data.di import com.shifthackz.aisdv1.data.gateway.DatabaseClearGatewayImpl import com.shifthackz.aisdv1.data.gateway.mediastore.MediaStoreGatewayFactory import com.shifthackz.aisdv1.data.local.DownloadableModelLocalDataSource +import com.shifthackz.aisdv1.data.local.EmbeddingsLocalDataSource import com.shifthackz.aisdv1.data.local.GenerationResultLocalDataSource import com.shifthackz.aisdv1.data.local.HuggingFaceModelsLocalDataSource +import com.shifthackz.aisdv1.data.local.LorasLocalDataSource import com.shifthackz.aisdv1.data.local.ServerConfigurationLocalDataSource import com.shifthackz.aisdv1.data.local.StabilityAiCreditsLocalDataSource -import com.shifthackz.aisdv1.data.local.StableDiffusionEmbeddingsLocalDataSource import com.shifthackz.aisdv1.data.local.StableDiffusionHyperNetworksLocalDataSource -import com.shifthackz.aisdv1.data.local.StableDiffusionLorasLocalDataSource import com.shifthackz.aisdv1.data.local.StableDiffusionModelsLocalDataSource import com.shifthackz.aisdv1.data.local.StableDiffusionSamplersLocalDataSource +import com.shifthackz.aisdv1.data.local.SwarmUiModelsLocalDataSource import com.shifthackz.aisdv1.domain.datasource.DownloadableModelDataSource +import com.shifthackz.aisdv1.domain.datasource.EmbeddingsDataSource import com.shifthackz.aisdv1.domain.datasource.GenerationResultDataSource import com.shifthackz.aisdv1.domain.datasource.HuggingFaceModelsDataSource +import com.shifthackz.aisdv1.domain.datasource.LorasDataSource import com.shifthackz.aisdv1.domain.datasource.ServerConfigurationDataSource import com.shifthackz.aisdv1.domain.datasource.StabilityAiCreditsDataSource -import com.shifthackz.aisdv1.domain.datasource.StableDiffusionEmbeddingsDataSource import com.shifthackz.aisdv1.domain.datasource.StableDiffusionHyperNetworksDataSource -import com.shifthackz.aisdv1.domain.datasource.StableDiffusionLorasDataSource import com.shifthackz.aisdv1.domain.datasource.StableDiffusionModelsDataSource import com.shifthackz.aisdv1.domain.datasource.StableDiffusionSamplersDataSource +import com.shifthackz.aisdv1.domain.datasource.SwarmUiModelsDataSource import com.shifthackz.aisdv1.domain.gateway.DatabaseClearGateway import org.koin.android.ext.koin.androidContext import org.koin.core.module.dsl.factoryOf @@ -35,9 +37,10 @@ val localDataSourceModule = module { single { StabilityAiCreditsLocalDataSource() } factoryOf(::StableDiffusionModelsLocalDataSource) bind StableDiffusionModelsDataSource.Local::class factoryOf(::StableDiffusionSamplersLocalDataSource) bind StableDiffusionSamplersDataSource.Local::class - factoryOf(::StableDiffusionLorasLocalDataSource) bind StableDiffusionLorasDataSource.Local::class + factoryOf(::LorasLocalDataSource) bind LorasDataSource.Local::class factoryOf(::StableDiffusionHyperNetworksLocalDataSource) bind StableDiffusionHyperNetworksDataSource.Local::class - factoryOf(::StableDiffusionEmbeddingsLocalDataSource) bind StableDiffusionEmbeddingsDataSource.Local::class + factoryOf(::EmbeddingsLocalDataSource) bind EmbeddingsDataSource.Local::class + factoryOf(::SwarmUiModelsLocalDataSource) bind SwarmUiModelsDataSource.Local::class factoryOf(::ServerConfigurationLocalDataSource) bind ServerConfigurationDataSource.Local::class factoryOf(::GenerationResultLocalDataSource) bind GenerationResultDataSource.Local::class factoryOf(::DownloadableModelLocalDataSource) bind DownloadableModelDataSource.Local::class diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/di/RemoteDataSourceModule.kt b/data/src/main/java/com/shifthackz/aisdv1/data/di/RemoteDataSourceModule.kt index 05d54d9a..a420ba6d 100755 --- a/data/src/main/java/com/shifthackz/aisdv1/data/di/RemoteDataSourceModule.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/di/RemoteDataSourceModule.kt @@ -20,22 +20,30 @@ import com.shifthackz.aisdv1.data.remote.StableDiffusionHyperNetworksRemoteDataS import com.shifthackz.aisdv1.data.remote.StableDiffusionLorasRemoteDataSource import com.shifthackz.aisdv1.data.remote.StableDiffusionModelsRemoteDataSource import com.shifthackz.aisdv1.data.remote.StableDiffusionSamplersRemoteDataSource +import com.shifthackz.aisdv1.data.remote.SwarmUiEmbeddingsRemoteDataSource +import com.shifthackz.aisdv1.data.remote.SwarmUiGenerationRemoteDataSource +import com.shifthackz.aisdv1.data.remote.SwarmUiLorasRemoteDataSource +import com.shifthackz.aisdv1.data.remote.SwarmUiModelsRemoteDataSource +import com.shifthackz.aisdv1.data.remote.SwarmUiSessionDataSourceImpl import com.shifthackz.aisdv1.domain.datasource.DownloadableModelDataSource +import com.shifthackz.aisdv1.domain.datasource.EmbeddingsDataSource import com.shifthackz.aisdv1.domain.datasource.HordeGenerationDataSource import com.shifthackz.aisdv1.domain.datasource.HuggingFaceGenerationDataSource import com.shifthackz.aisdv1.domain.datasource.HuggingFaceModelsDataSource +import com.shifthackz.aisdv1.domain.datasource.LorasDataSource import com.shifthackz.aisdv1.domain.datasource.OpenAiGenerationDataSource import com.shifthackz.aisdv1.domain.datasource.RandomImageDataSource import com.shifthackz.aisdv1.domain.datasource.ServerConfigurationDataSource import com.shifthackz.aisdv1.domain.datasource.StabilityAiCreditsDataSource import com.shifthackz.aisdv1.domain.datasource.StabilityAiEnginesDataSource import com.shifthackz.aisdv1.domain.datasource.StabilityAiGenerationDataSource -import com.shifthackz.aisdv1.domain.datasource.StableDiffusionEmbeddingsDataSource import com.shifthackz.aisdv1.domain.datasource.StableDiffusionGenerationDataSource import com.shifthackz.aisdv1.domain.datasource.StableDiffusionHyperNetworksDataSource -import com.shifthackz.aisdv1.domain.datasource.StableDiffusionLorasDataSource import com.shifthackz.aisdv1.domain.datasource.StableDiffusionModelsDataSource import com.shifthackz.aisdv1.domain.datasource.StableDiffusionSamplersDataSource +import com.shifthackz.aisdv1.domain.datasource.SwarmUiGenerationDataSource +import com.shifthackz.aisdv1.domain.datasource.SwarmUiModelsDataSource +import com.shifthackz.aisdv1.domain.datasource.SwarmUiSessionDataSource import com.shifthackz.aisdv1.domain.entity.ServerSource import com.shifthackz.aisdv1.domain.gateway.ServerConnectivityGateway import com.shifthackz.aisdv1.domain.preference.PreferenceManager @@ -51,8 +59,12 @@ val remoteDataSourceModule = module { single { ServerUrlProvider { endpoint -> val prefs = get() - Single - .fromCallable(prefs::serverUrl) + val chain = if (prefs.source == ServerSource.SWARM_UI) { + Single.fromCallable(prefs::swarmUiServerUrl) + } else { + Single.fromCallable(prefs::automatic1111ServerUrl) + } + chain .map(String::fixUrlSlashes) .map { baseUrl -> "$baseUrl/$endpoint" } } @@ -61,12 +73,17 @@ val remoteDataSourceModule = module { factoryOf(::HordeGenerationRemoteDataSource) bind HordeGenerationDataSource.Remote::class factoryOf(::HuggingFaceGenerationRemoteDataSource) bind HuggingFaceGenerationDataSource.Remote::class factoryOf(::OpenAiGenerationRemoteDataSource) bind OpenAiGenerationDataSource.Remote::class + factoryOf(::SwarmUiSessionDataSourceImpl) bind SwarmUiSessionDataSource::class + factoryOf(::SwarmUiGenerationRemoteDataSource) bind SwarmUiGenerationDataSource.Remote::class + factoryOf(::SwarmUiModelsRemoteDataSource) bind SwarmUiModelsDataSource.Remote::class + factoryOf(::SwarmUiLorasRemoteDataSource) bind LorasDataSource.Remote.SwarmUi::class + factoryOf(::SwarmUiEmbeddingsRemoteDataSource) bind EmbeddingsDataSource.Remote.SwarmUi::class factoryOf(::StableDiffusionGenerationRemoteDataSource) bind StableDiffusionGenerationDataSource.Remote::class factoryOf(::StableDiffusionSamplersRemoteDataSource) bind StableDiffusionSamplersDataSource.Remote::class factoryOf(::StableDiffusionModelsRemoteDataSource) bind StableDiffusionModelsDataSource.Remote::class - factoryOf(::StableDiffusionLorasRemoteDataSource) bind StableDiffusionLorasDataSource.Remote::class + factoryOf(::StableDiffusionLorasRemoteDataSource) bind LorasDataSource.Remote.Automatic1111::class factoryOf(::StableDiffusionHyperNetworksRemoteDataSource) bind StableDiffusionHyperNetworksDataSource.Remote::class - factoryOf(::StableDiffusionEmbeddingsRemoteDataSource) bind StableDiffusionEmbeddingsDataSource.Remote::class + factoryOf(::StableDiffusionEmbeddingsRemoteDataSource) bind EmbeddingsDataSource.Remote.Automatic1111::class factoryOf(::ServerConfigurationRemoteDataSource) bind ServerConfigurationDataSource.Remote::class factoryOf(::RandomImageRemoteDataSource) bind RandomImageDataSource.Remote::class factoryOf(::DownloadableModelRemoteDataSource) bind DownloadableModelDataSource.Remote::class @@ -78,7 +95,7 @@ val remoteDataSourceModule = module { factory { val lambda: () -> Boolean = { val prefs = get() - prefs.source == ServerSource.AUTOMATIC1111 + prefs.source == ServerSource.AUTOMATIC1111 || prefs.source == ServerSource.SWARM_UI } val monitor = get { parametersOf(lambda) } ServerConnectivityGatewayImpl(monitor, get()) diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/di/RepositoryModule.kt b/data/src/main/java/com/shifthackz/aisdv1/data/di/RepositoryModule.kt index ea0d16ca..d8e40880 100755 --- a/data/src/main/java/com/shifthackz/aisdv1/data/di/RepositoryModule.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/di/RepositoryModule.kt @@ -3,43 +3,47 @@ package com.shifthackz.aisdv1.data.di import android.content.Context import android.os.PowerManager import com.shifthackz.aisdv1.data.repository.DownloadableModelRepositoryImpl +import com.shifthackz.aisdv1.data.repository.EmbeddingsRepositoryImpl import com.shifthackz.aisdv1.data.repository.GenerationResultRepositoryImpl import com.shifthackz.aisdv1.data.repository.HordeGenerationRepositoryImpl import com.shifthackz.aisdv1.data.repository.HuggingFaceGenerationRepositoryImpl import com.shifthackz.aisdv1.data.repository.HuggingFaceModelsRepositoryImpl import com.shifthackz.aisdv1.data.repository.LocalDiffusionGenerationRepositoryImpl +import com.shifthackz.aisdv1.data.repository.LorasRepositoryImpl import com.shifthackz.aisdv1.data.repository.OpenAiGenerationRepositoryImpl import com.shifthackz.aisdv1.data.repository.RandomImageRepositoryImpl import com.shifthackz.aisdv1.data.repository.ServerConfigurationRepositoryImpl import com.shifthackz.aisdv1.data.repository.StabilityAiCreditsRepositoryImpl import com.shifthackz.aisdv1.data.repository.StabilityAiEnginesRepositoryImpl import com.shifthackz.aisdv1.data.repository.StabilityAiGenerationRepositoryImpl -import com.shifthackz.aisdv1.data.repository.StableDiffusionEmbeddingsRepositoryImpl import com.shifthackz.aisdv1.data.repository.StableDiffusionGenerationRepositoryImpl import com.shifthackz.aisdv1.data.repository.StableDiffusionHyperNetworksRepositoryImpl -import com.shifthackz.aisdv1.data.repository.StableDiffusionLorasRepositoryImpl import com.shifthackz.aisdv1.data.repository.StableDiffusionModelsRepositoryImpl import com.shifthackz.aisdv1.data.repository.StableDiffusionSamplersRepositoryImpl +import com.shifthackz.aisdv1.data.repository.SwarmUiGenerationRepositoryImpl +import com.shifthackz.aisdv1.data.repository.SwarmUiModelsRepositoryImpl import com.shifthackz.aisdv1.data.repository.TemporaryGenerationResultRepositoryImpl import com.shifthackz.aisdv1.data.repository.WakeLockRepositoryImpl import com.shifthackz.aisdv1.domain.repository.DownloadableModelRepository +import com.shifthackz.aisdv1.domain.repository.EmbeddingsRepository import com.shifthackz.aisdv1.domain.repository.GenerationResultRepository import com.shifthackz.aisdv1.domain.repository.HordeGenerationRepository import com.shifthackz.aisdv1.domain.repository.HuggingFaceGenerationRepository import com.shifthackz.aisdv1.domain.repository.HuggingFaceModelsRepository import com.shifthackz.aisdv1.domain.repository.LocalDiffusionGenerationRepository +import com.shifthackz.aisdv1.domain.repository.LorasRepository import com.shifthackz.aisdv1.domain.repository.OpenAiGenerationRepository import com.shifthackz.aisdv1.domain.repository.RandomImageRepository import com.shifthackz.aisdv1.domain.repository.ServerConfigurationRepository import com.shifthackz.aisdv1.domain.repository.StabilityAiCreditsRepository import com.shifthackz.aisdv1.domain.repository.StabilityAiEnginesRepository import com.shifthackz.aisdv1.domain.repository.StabilityAiGenerationRepository -import com.shifthackz.aisdv1.domain.repository.StableDiffusionEmbeddingsRepository import com.shifthackz.aisdv1.domain.repository.StableDiffusionGenerationRepository import com.shifthackz.aisdv1.domain.repository.StableDiffusionHyperNetworksRepository -import com.shifthackz.aisdv1.domain.repository.StableDiffusionLorasRepository import com.shifthackz.aisdv1.domain.repository.StableDiffusionModelsRepository import com.shifthackz.aisdv1.domain.repository.StableDiffusionSamplersRepository +import com.shifthackz.aisdv1.domain.repository.SwarmUiGenerationRepository +import com.shifthackz.aisdv1.domain.repository.SwarmUiModelsRepository import com.shifthackz.aisdv1.domain.repository.TemporaryGenerationResultRepository import com.shifthackz.aisdv1.domain.repository.WakeLockRepository import org.koin.android.ext.koin.androidContext @@ -60,15 +64,17 @@ val repositoryModule = module { factoryOf(::HordeGenerationRepositoryImpl) bind HordeGenerationRepository::class factoryOf(::HuggingFaceGenerationRepositoryImpl) bind HuggingFaceGenerationRepository::class factoryOf(::OpenAiGenerationRepositoryImpl) bind OpenAiGenerationRepository::class + factoryOf(::SwarmUiGenerationRepositoryImpl) bind SwarmUiGenerationRepository::class + factoryOf(::SwarmUiModelsRepositoryImpl) bind SwarmUiModelsRepository::class factoryOf(::StabilityAiGenerationRepositoryImpl) bind StabilityAiGenerationRepository::class factoryOf(::StabilityAiCreditsRepositoryImpl) bind StabilityAiCreditsRepository::class factoryOf(::StabilityAiEnginesRepositoryImpl) bind StabilityAiEnginesRepository::class factoryOf(::StableDiffusionGenerationRepositoryImpl) bind StableDiffusionGenerationRepository::class factoryOf(::StableDiffusionModelsRepositoryImpl) bind StableDiffusionModelsRepository::class factoryOf(::StableDiffusionSamplersRepositoryImpl) bind StableDiffusionSamplersRepository::class - factoryOf(::StableDiffusionLorasRepositoryImpl) bind StableDiffusionLorasRepository::class + factoryOf(::LorasRepositoryImpl) bind LorasRepository::class factoryOf(::StableDiffusionHyperNetworksRepositoryImpl) bind StableDiffusionHyperNetworksRepository::class - factoryOf(::StableDiffusionEmbeddingsRepositoryImpl) bind StableDiffusionEmbeddingsRepository::class + factoryOf(::EmbeddingsRepositoryImpl) bind EmbeddingsRepository::class factoryOf(::ServerConfigurationRepositoryImpl) bind ServerConfigurationRepository::class factoryOf(::GenerationResultRepositoryImpl) bind GenerationResultRepository::class factoryOf(::RandomImageRepositoryImpl) bind RandomImageRepository::class diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/local/StableDiffusionEmbeddingsLocalDataSource.kt b/data/src/main/java/com/shifthackz/aisdv1/data/local/EmbeddingsLocalDataSource.kt similarity index 63% rename from data/src/main/java/com/shifthackz/aisdv1/data/local/StableDiffusionEmbeddingsLocalDataSource.kt rename to data/src/main/java/com/shifthackz/aisdv1/data/local/EmbeddingsLocalDataSource.kt index f86b3d28..66db4d3c 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/local/StableDiffusionEmbeddingsLocalDataSource.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/local/EmbeddingsLocalDataSource.kt @@ -2,20 +2,20 @@ package com.shifthackz.aisdv1.data.local import com.shifthackz.aisdv1.data.mappers.mapDomainToEntity import com.shifthackz.aisdv1.data.mappers.mapEntityToDomain -import com.shifthackz.aisdv1.domain.datasource.StableDiffusionEmbeddingsDataSource -import com.shifthackz.aisdv1.domain.entity.StableDiffusionEmbedding +import com.shifthackz.aisdv1.domain.datasource.EmbeddingsDataSource +import com.shifthackz.aisdv1.domain.entity.Embedding import com.shifthackz.aisdv1.storage.db.cache.dao.StableDiffusionEmbeddingDao import com.shifthackz.aisdv1.storage.db.cache.entity.StableDiffusionEmbeddingEntity -internal class StableDiffusionEmbeddingsLocalDataSource( +internal class EmbeddingsLocalDataSource( private val dao: StableDiffusionEmbeddingDao, -) : StableDiffusionEmbeddingsDataSource.Local { +) : EmbeddingsDataSource.Local { override fun getEmbeddings() = dao .queryAll() .map(List::mapEntityToDomain) - override fun insertEmbeddings(list: List) = dao + override fun insertEmbeddings(list: List) = dao .deleteAll() .andThen(dao.insertList(list.mapDomainToEntity())) } diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/local/StableDiffusionLorasLocalDataSource.kt b/data/src/main/java/com/shifthackz/aisdv1/data/local/LorasLocalDataSource.kt similarity index 64% rename from data/src/main/java/com/shifthackz/aisdv1/data/local/StableDiffusionLorasLocalDataSource.kt rename to data/src/main/java/com/shifthackz/aisdv1/data/local/LorasLocalDataSource.kt index e130ade2..e38771e6 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/local/StableDiffusionLorasLocalDataSource.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/local/LorasLocalDataSource.kt @@ -2,20 +2,20 @@ package com.shifthackz.aisdv1.data.local import com.shifthackz.aisdv1.data.mappers.mapDomainToEntity import com.shifthackz.aisdv1.data.mappers.mapEntityToDomain -import com.shifthackz.aisdv1.domain.datasource.StableDiffusionLorasDataSource -import com.shifthackz.aisdv1.domain.entity.StableDiffusionLora +import com.shifthackz.aisdv1.domain.datasource.LorasDataSource +import com.shifthackz.aisdv1.domain.entity.LoRA import com.shifthackz.aisdv1.storage.db.cache.dao.StableDiffusionLoraDao import com.shifthackz.aisdv1.storage.db.cache.entity.StableDiffusionLoraEntity -internal class StableDiffusionLorasLocalDataSource( +internal class LorasLocalDataSource( private val dao: StableDiffusionLoraDao, -) : StableDiffusionLorasDataSource.Local { +) : LorasDataSource.Local { override fun getLoras() = dao .queryAll() .map(List::mapEntityToDomain) - override fun insertLoras(loras: List) = dao + override fun insertLoras(loras: List) = dao .deleteAll() .andThen(dao.insertList(loras.mapDomainToEntity())) } diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/local/SwarmUiModelsLocalDataSource.kt b/data/src/main/java/com/shifthackz/aisdv1/data/local/SwarmUiModelsLocalDataSource.kt new file mode 100644 index 00000000..a40cec23 --- /dev/null +++ b/data/src/main/java/com/shifthackz/aisdv1/data/local/SwarmUiModelsLocalDataSource.kt @@ -0,0 +1,23 @@ +package com.shifthackz.aisdv1.data.local + +import com.shifthackz.aisdv1.data.mappers.mapDomainToEntity +import com.shifthackz.aisdv1.data.mappers.mapEntityToDomain +import com.shifthackz.aisdv1.domain.datasource.SwarmUiModelsDataSource +import com.shifthackz.aisdv1.domain.entity.SwarmUiModel +import com.shifthackz.aisdv1.storage.db.cache.dao.SwarmUiModelDao +import com.shifthackz.aisdv1.storage.db.cache.entity.SwarmUiModelEntity +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single + +internal class SwarmUiModelsLocalDataSource( + private val dao: SwarmUiModelDao, +) : SwarmUiModelsDataSource.Local { + + override fun getModels(): Single> = dao + .queryAll() + .map(List::mapEntityToDomain) + + override fun insertModels(models: List): Completable = dao + .deleteAll() + .andThen(dao.insertList(models.mapDomainToEntity())) +} diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/HuggingFaceModelMappers.kt b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/HuggingFaceModelMappers.kt index 40c5d114..3ed4ad39 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/HuggingFaceModelMappers.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/HuggingFaceModelMappers.kt @@ -5,10 +5,10 @@ import com.shifthackz.aisdv1.network.model.HuggingFaceModelRaw import com.shifthackz.aisdv1.storage.db.persistent.entity.HuggingFaceModelEntity //region RAW --> DOMAIN -fun List.mapRawToDomain(): List = - map(HuggingFaceModelRaw::mapRawToDomain) +fun List.mapRawToCheckpointDomain(): List = + map(HuggingFaceModelRaw::mapRawToCheckpointDomain) -fun HuggingFaceModelRaw.mapRawToDomain(): HuggingFaceModel = with(this) { +fun HuggingFaceModelRaw.mapRawToCheckpointDomain(): HuggingFaceModel = with(this) { HuggingFaceModel( id = id ?: "", name = name ?: "", diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/ImageToImagePayloadMappers.kt b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/ImageToImagePayloadMappers.kt index 75f15ac7..467ca8a3 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/ImageToImagePayloadMappers.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/ImageToImagePayloadMappers.kt @@ -1,5 +1,6 @@ package com.shifthackz.aisdv1.data.mappers +import com.shifthackz.aisdv1.core.common.math.roundTo import com.shifthackz.aisdv1.domain.entity.AiGenerationResult import com.shifthackz.aisdv1.domain.entity.ImageToImagePayload import com.shifthackz.aisdv1.domain.entity.StabilityAiClipGuidance @@ -7,9 +8,11 @@ import com.shifthackz.aisdv1.domain.entity.StabilityAiStylePreset import com.shifthackz.aisdv1.network.request.HordeGenerationAsyncRequest import com.shifthackz.aisdv1.network.request.HuggingFaceGenerationRequest import com.shifthackz.aisdv1.network.request.ImageToImageRequest +import com.shifthackz.aisdv1.network.request.SwarmUiGenerationRequest import com.shifthackz.aisdv1.network.response.SdGenerationResponse import java.util.Date +//region PAYLOAD --> REQUEST fun ImageToImagePayload.mapToRequest(): ImageToImageRequest = with(this) { ImageToImageRequest( initImages = listOf(base64Image), @@ -93,6 +96,30 @@ fun ImageToImagePayload.mapToStabilityAiRequest() = with(this) { } } +fun ImageToImagePayload.mapToSwarmUiRequest( + sessionId: String, + swarmUiModel: String, +): SwarmUiGenerationRequest = with(this) { + SwarmUiGenerationRequest( + sessionId = sessionId, + model = swarmUiModel, + initImage = base64Image, + initImageCreativity = denoisingStrength.roundTo(2).toString(), + images = 1, + prompt = prompt, + negativePrompt = negativePrompt, + width = width, + height = height, + seed = seed.trim().ifEmpty { null }, + variationSeed = subSeed.trim().ifEmpty { null }, + variationSeedStrength = subSeedStrength.takeIf { it >= 0.1 }?.toString(), + cfgScale = cfgScale, + steps = samplingSteps, + ) +} +//endregion + +//region RESPONSE --> RESULT fun Pair.mapToAiGenResult(): AiGenerationResult = let { (payload, response) -> AiGenerationResult( @@ -140,3 +167,4 @@ fun Pair.mapCloudToAiGenResult(): AiGenerationResul subSeedStrength = payload.subSeedStrength, ) } +//endregion diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/LocalAiModelMappers.kt b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/LocalAiModelMappers.kt index 7b8fbca6..d9261085 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/LocalAiModelMappers.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/LocalAiModelMappers.kt @@ -5,10 +5,10 @@ import com.shifthackz.aisdv1.network.response.DownloadableModelResponse import com.shifthackz.aisdv1.storage.db.persistent.entity.LocalModelEntity //region RAW --> DOMAIN -fun List.mapRawToDomain(): List = - map(DownloadableModelResponse::mapRawToDomain) +fun List.mapRawToCheckpointDomain(): List = + map(DownloadableModelResponse::mapRawToCheckpointDomain) -fun DownloadableModelResponse.mapRawToDomain(): LocalAiModel = with(this) { +fun DownloadableModelResponse.mapRawToCheckpointDomain(): LocalAiModel = with(this) { LocalAiModel( id = id ?: "", name = name ?: "", diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/StabilityAiEngineMappers.kt b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/StabilityAiEngineMappers.kt index d6ba31f3..44999ca4 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/StabilityAiEngineMappers.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/StabilityAiEngineMappers.kt @@ -4,10 +4,10 @@ import com.shifthackz.aisdv1.domain.entity.StabilityAiEngine import com.shifthackz.aisdv1.network.model.StabilityAiEngineRaw //region RAW --> DOMAIN -fun List.mapRawToDomain(): List = - map(StabilityAiEngineRaw::mapRawToDomain) +fun List.mapRawToCheckpointDomain(): List = + map(StabilityAiEngineRaw::mapRawToCheckpointDomain) -fun StabilityAiEngineRaw.mapRawToDomain(): StabilityAiEngine = with(this) { +fun StabilityAiEngineRaw.mapRawToCheckpointDomain(): StabilityAiEngine = with(this) { StabilityAiEngine(id ?: "", name ?: "") } //endregion diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/StableDiffusionEmbeddingsMappers.kt b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/StableDiffusionEmbeddingsMappers.kt index 554d1efb..8403d436 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/StableDiffusionEmbeddingsMappers.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/StableDiffusionEmbeddingsMappers.kt @@ -1,20 +1,20 @@ package com.shifthackz.aisdv1.data.mappers -import com.shifthackz.aisdv1.domain.entity.StableDiffusionEmbedding +import com.shifthackz.aisdv1.domain.entity.Embedding import com.shifthackz.aisdv1.network.response.SdEmbeddingsResponse import com.shifthackz.aisdv1.storage.db.cache.entity.StableDiffusionEmbeddingEntity //region RAW -> DOMAIN -fun SdEmbeddingsResponse.mapRawToDomain(): List = - loaded?.keys?.map(::StableDiffusionEmbedding) ?: emptyList() +fun SdEmbeddingsResponse.mapRawToCheckpointDomain(): List = + loaded?.keys?.map(::Embedding) ?: emptyList() //endregion //region DOMAIN -> ENTITY -fun List.mapDomainToEntity(): List = - map(StableDiffusionEmbedding::mapDomainToEntity) +fun List.mapDomainToEntity(): List = + map(Embedding::mapDomainToEntity) -fun StableDiffusionEmbedding.mapDomainToEntity(): StableDiffusionEmbeddingEntity = with(this) { +fun Embedding.mapDomainToEntity(): StableDiffusionEmbeddingEntity = with(this) { StableDiffusionEmbeddingEntity( id = keyword, keyword = keyword, @@ -23,10 +23,10 @@ fun StableDiffusionEmbedding.mapDomainToEntity(): StableDiffusionEmbeddingEntity //endregion //region ENTITY -> DOMAIN -fun List.mapEntityToDomain(): List = +fun List.mapEntityToDomain(): List = map(StableDiffusionEmbeddingEntity::mapEntityToDomain) -fun StableDiffusionEmbeddingEntity.mapEntityToDomain(): StableDiffusionEmbedding = with(this) { - StableDiffusionEmbedding(keyword) +fun StableDiffusionEmbeddingEntity.mapEntityToDomain(): Embedding = with(this) { + Embedding(keyword) } //endregion diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/StableDiffusionHyperNetworksMappers.kt b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/StableDiffusionHyperNetworksMappers.kt index 646d4eb7..0690d860 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/StableDiffusionHyperNetworksMappers.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/StableDiffusionHyperNetworksMappers.kt @@ -5,10 +5,10 @@ import com.shifthackz.aisdv1.network.model.StableDiffusionHyperNetworkRaw import com.shifthackz.aisdv1.storage.db.cache.entity.StableDiffusionHyperNetworkEntity //region RAW -> DOMAIN -fun List.mapRawToDomain(): List = - map(StableDiffusionHyperNetworkRaw::mapRawToDomain) +fun List.mapRawToCheckpointDomain(): List = + map(StableDiffusionHyperNetworkRaw::mapRawToCheckpointDomain) -fun StableDiffusionHyperNetworkRaw.mapRawToDomain(): StableDiffusionHyperNetwork = with(this) { +fun StableDiffusionHyperNetworkRaw.mapRawToCheckpointDomain(): StableDiffusionHyperNetwork = with(this) { StableDiffusionHyperNetwork( name = name ?: "", path = path ?: "", diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/StableDiffusionLorasMappers.kt b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/StableDiffusionLorasMappers.kt index f22c2018..ab382998 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/StableDiffusionLorasMappers.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/StableDiffusionLorasMappers.kt @@ -1,15 +1,15 @@ package com.shifthackz.aisdv1.data.mappers -import com.shifthackz.aisdv1.domain.entity.StableDiffusionLora +import com.shifthackz.aisdv1.domain.entity.LoRA import com.shifthackz.aisdv1.network.model.StableDiffusionLoraRaw import com.shifthackz.aisdv1.storage.db.cache.entity.StableDiffusionLoraEntity //region RAW --> DOMAIN -fun List.mapToDomain(): List = +fun List.mapToDomain(): List = map(StableDiffusionLoraRaw::mapToDomain) -fun StableDiffusionLoraRaw.mapToDomain(): StableDiffusionLora = with(this) { - StableDiffusionLora( +fun StableDiffusionLoraRaw.mapToDomain(): LoRA = with(this) { + LoRA( name = name ?: "", alias = alias ?: "", path = path ?: "", @@ -18,10 +18,10 @@ fun StableDiffusionLoraRaw.mapToDomain(): StableDiffusionLora = with(this) { //endregion //region DOMAIN -> ENTITY -fun List.mapDomainToEntity(): List = - map(StableDiffusionLora::mapDomainToEntity) +fun List.mapDomainToEntity(): List = + map(LoRA::mapDomainToEntity) -fun StableDiffusionLora.mapDomainToEntity(): StableDiffusionLoraEntity = with(this) { +fun LoRA.mapDomainToEntity(): StableDiffusionLoraEntity = with(this) { StableDiffusionLoraEntity( id = name, name = name, @@ -32,11 +32,11 @@ fun StableDiffusionLora.mapDomainToEntity(): StableDiffusionLoraEntity = with(th //endregion //region ENTITY -> DOMAIN -fun List.mapEntityToDomain(): List = +fun List.mapEntityToDomain(): List = map(StableDiffusionLoraEntity::mapEntityToDomain) -fun StableDiffusionLoraEntity.mapEntityToDomain(): StableDiffusionLora = with(this) { - StableDiffusionLora( +fun StableDiffusionLoraEntity.mapEntityToDomain(): LoRA = with(this) { + LoRA( name = name, alias = alias, path = path, diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/StableDiffusionModelsMappers.kt b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/StableDiffusionModelsMappers.kt index 8128cdf6..b9d03c2d 100755 --- a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/StableDiffusionModelsMappers.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/StableDiffusionModelsMappers.kt @@ -5,10 +5,10 @@ import com.shifthackz.aisdv1.network.model.StableDiffusionModelRaw import com.shifthackz.aisdv1.storage.db.cache.entity.StableDiffusionModelEntity //region RAW --> DOMAIN -fun List.mapRawToDomain(): List = - map(StableDiffusionModelRaw::mapRawToDomain) +fun List.mapRawToCheckpointDomain(): List = + map(StableDiffusionModelRaw::mapRawToCheckpointDomain) -fun StableDiffusionModelRaw.mapRawToDomain(): StableDiffusionModel = with(this) { +fun StableDiffusionModelRaw.mapRawToCheckpointDomain(): StableDiffusionModel = with(this) { StableDiffusionModel( title = title ?: "", modelName = modelName ?: "", diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/StableDiffusionSamplersMappers.kt b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/StableDiffusionSamplersMappers.kt index af87956c..dd932fa8 100755 --- a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/StableDiffusionSamplersMappers.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/StableDiffusionSamplersMappers.kt @@ -5,10 +5,10 @@ import com.shifthackz.aisdv1.network.model.StableDiffusionSamplerRaw import com.shifthackz.aisdv1.storage.db.cache.entity.StableDiffusionSamplerEntity //region RAW --> DOMAIN -fun List.mapRawToDomain(): List = - map(StableDiffusionSamplerRaw::mapRawToDomain) +fun List.mapRawToCheckpointDomain(): List = + map(StableDiffusionSamplerRaw::mapRawToCheckpointDomain) -fun StableDiffusionSamplerRaw.mapRawToDomain(): StableDiffusionSampler = with(this) { +fun StableDiffusionSamplerRaw.mapRawToCheckpointDomain(): StableDiffusionSampler = with(this) { StableDiffusionSampler( name = name ?: "", aliases = aliases ?: emptyList(), diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/SwarmUiModelsMappers.kt b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/SwarmUiModelsMappers.kt new file mode 100644 index 00000000..931fa4d0 --- /dev/null +++ b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/SwarmUiModelsMappers.kt @@ -0,0 +1,77 @@ +package com.shifthackz.aisdv1.data.mappers + +import com.shifthackz.aisdv1.domain.entity.Embedding +import com.shifthackz.aisdv1.domain.entity.LoRA +import com.shifthackz.aisdv1.domain.entity.SwarmUiModel +import com.shifthackz.aisdv1.network.model.SwarmUiModelRaw +import com.shifthackz.aisdv1.network.response.SwarmUiModelsResponse +import com.shifthackz.aisdv1.storage.db.cache.entity.SwarmUiModelEntity + +//region RAW --> CHECKPOINT DOMAIN +fun SwarmUiModelsResponse.mapRawToCheckpointDomain(): List = with(this) { + this.files?.mapRawToCheckpointDomain() ?: emptyList() +} + +fun List.mapRawToCheckpointDomain(): List = map(SwarmUiModelRaw::mapRawToCheckpointDomain) + +fun SwarmUiModelRaw.mapRawToCheckpointDomain(): SwarmUiModel = with(this) { + SwarmUiModel( + name = name ?: "", + title = title ?: "", + author = author ?: "", + ) +} +//endregion + +//region RAW --> LORA DOMAIN +fun SwarmUiModelsResponse.mapRawToLoraDomain(): List = with(this) { + this.files?.mapRawToLoraDomain() ?: emptyList() +} + +fun List.mapRawToLoraDomain(): List = map(SwarmUiModelRaw::mapRawToLoraDomain) + +fun SwarmUiModelRaw.mapRawToLoraDomain(): LoRA = with(this) { + LoRA( + name = name ?: "", + alias = title ?: "", + path = "", + ) +} +//endregion + +//region RAW -> EMBEDDING DOMAIN +fun SwarmUiModelsResponse.mapRawToEmbeddingDomain(): List = with(this) { + this.files?.mapRawToEmbeddingDomain() ?: emptyList() +} + +fun List.mapRawToEmbeddingDomain(): List = map(SwarmUiModelRaw::mapRawToEmbeddingDomain) + +fun SwarmUiModelRaw.mapRawToEmbeddingDomain(): Embedding = with(this) { + Embedding(title ?: "") +} +//endregion + +//region CHECKPOINT DOMAIN --> ENTITY +fun List.mapDomainToEntity(): List = map(SwarmUiModel::mapDomainToEntity) + +fun SwarmUiModel.mapDomainToEntity(): SwarmUiModelEntity = with(this) { + SwarmUiModelEntity( + id = "${name}_${title}", + name = name, + title = title, + author = author, + ) +} +//endregion + +//region ENTITY --> CHECKPOINT DOMAIN +fun List.mapEntityToDomain(): List = map(SwarmUiModelEntity::mapEntityToDomain) + +fun SwarmUiModelEntity.mapEntityToDomain(): SwarmUiModel = with(this) { + SwarmUiModel( + name = name, + title = title, + author = author, + ) +} +//endregion diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/TextToImagePayloadMappers.kt b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/TextToImagePayloadMappers.kt index 37dba58a..36c76b21 100755 --- a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/TextToImagePayloadMappers.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/TextToImagePayloadMappers.kt @@ -10,10 +10,12 @@ import com.shifthackz.aisdv1.network.request.HordeGenerationAsyncRequest import com.shifthackz.aisdv1.network.request.HuggingFaceGenerationRequest import com.shifthackz.aisdv1.network.request.OpenAiRequest import com.shifthackz.aisdv1.network.request.StabilityTextToImageRequest +import com.shifthackz.aisdv1.network.request.SwarmUiGenerationRequest import com.shifthackz.aisdv1.network.request.TextToImageRequest import com.shifthackz.aisdv1.network.response.SdGenerationResponse import java.util.Date +//region PAYLOAD --> REQUEST fun TextToImagePayload.mapToRequest(): TextToImageRequest = with(this) { TextToImageRequest( prompt = prompt, @@ -95,6 +97,30 @@ fun TextToImagePayload.mapToStabilityAiRequest(): StabilityTextToImageRequest = ) } +fun TextToImagePayload.mapToSwarmUiRequest( + sessionId: String, + swarmUiModel: String, +): SwarmUiGenerationRequest = with(this) { + SwarmUiGenerationRequest( + sessionId = sessionId, + model = swarmUiModel, + initImage = null, + initImageCreativity = null, + images = 1, + prompt = prompt, + negativePrompt = negativePrompt, + width = width, + height = height, + seed = seed.trim().ifEmpty { null }, + variationSeed = subSeed.trim().ifEmpty { null }, + variationSeedStrength = subSeedStrength.takeIf { it >= 0.1 }?.toString(), + cfgScale = cfgScale, + steps = samplingSteps, + ) +} +//endregion + +//region RESPONSE --> RESULT fun Pair.mapToAiGenResult(): AiGenerationResult = let { (payload, response) -> AiGenerationResult( @@ -165,3 +191,4 @@ fun Pair.mapLocalDiffusionToAiGenResult(): AiGenerat subSeedStrength = payload.subSeedStrength, ) } +//endregion diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImpl.kt index 20d91c74..ae130e47 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImpl.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImpl.kt @@ -20,13 +20,28 @@ class PreferenceManagerImpl( private val preferencesChangedSubject: BehaviorSubject = BehaviorSubject.createDefault(Unit) - override var serverUrl: String + override var automatic1111ServerUrl: String get() = (preferences.getString(KEY_SERVER_URL, "") ?: "").fixUrlSlashes() set(value) = preferences.edit() .putString(KEY_SERVER_URL, value.fixUrlSlashes()) .apply() .also { onPreferencesChanged() } + override var swarmUiServerUrl: String + get() = (preferences.getString(KEY_SWARM_SERVER_URL, "") ?: "").fixUrlSlashes() + set(value) = preferences.edit() + .putString(KEY_SWARM_SERVER_URL, value.fixUrlSlashes()) + .apply() + .also { onPreferencesChanged() } + + override var swarmUiModel: String + get() = preferences.getString(KEY_SWARM_MODEL, "") ?: "" + set(value) = preferences + .edit() + .putString(KEY_SWARM_MODEL, value) + .apply() + .also { onPreferencesChanged() } + override var demoMode: Boolean get() = preferences.getBoolean(KEY_DEMO_MODE, false) set(value) = preferences.edit() @@ -192,7 +207,7 @@ class PreferenceManagerImpl( .toFlowable(BackpressureStrategy.LATEST) .map { Settings( - serverUrl = serverUrl, + serverUrl = automatic1111ServerUrl, sdModel = sdModel, demoMode = demoMode, monitorConnectivity = monitorConnectivity, @@ -215,6 +230,8 @@ class PreferenceManagerImpl( companion object { const val KEY_SERVER_URL = "key_server_url" + const val KEY_SWARM_SERVER_URL = "key_swarm_server_url" + const val KEY_SWARM_MODEL = "key_swarm_model" const val KEY_DEMO_MODE = "key_demo_mode" const val KEY_MONITOR_CONNECTIVITY = "key_monitor_connectivity" const val KEY_AI_AUTO_SAVE = "key_ai_auto_save" diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/preference/SessionPreferenceImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/preference/SessionPreferenceImpl.kt index 2fb82077..00be950d 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/preference/SessionPreferenceImpl.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/preference/SessionPreferenceImpl.kt @@ -4,11 +4,11 @@ import com.shifthackz.aisdv1.domain.preference.SessionPreference class SessionPreferenceImpl : SessionPreference { - private var _coinsPerDay: Int = -1 + private var _swarmUiSessionId: String = "" - override var coinsPerDay: Int - get() = _coinsPerDay + override var swarmUiSessionId: String + get() = _swarmUiSessionId set(value) { - _coinsPerDay = value + _swarmUiSessionId = value } } diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/remote/DownloadableModelRemoteDataSource.kt b/data/src/main/java/com/shifthackz/aisdv1/data/remote/DownloadableModelRemoteDataSource.kt index 9241dcc6..a2e6b193 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/remote/DownloadableModelRemoteDataSource.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/remote/DownloadableModelRemoteDataSource.kt @@ -2,7 +2,7 @@ package com.shifthackz.aisdv1.data.remote import com.shifthackz.aisdv1.core.common.file.FileProviderDescriptor import com.shifthackz.aisdv1.core.common.file.unzip -import com.shifthackz.aisdv1.data.mappers.mapRawToDomain +import com.shifthackz.aisdv1.data.mappers.mapRawToCheckpointDomain import com.shifthackz.aisdv1.domain.datasource.DownloadableModelDataSource import com.shifthackz.aisdv1.domain.entity.DownloadState import com.shifthackz.aisdv1.network.api.sdai.DownloadableModelsApi @@ -18,7 +18,7 @@ internal class DownloadableModelRemoteDataSource( override fun fetch() = api .fetchDownloadableModels() - .map(List::mapRawToDomain) + .map(List::mapRawToCheckpointDomain) override fun download(id: String, url: String): Observable = Completable diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/remote/HuggingFaceModelsRemoteDataSource.kt b/data/src/main/java/com/shifthackz/aisdv1/data/remote/HuggingFaceModelsRemoteDataSource.kt index a336f341..73056260 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/remote/HuggingFaceModelsRemoteDataSource.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/remote/HuggingFaceModelsRemoteDataSource.kt @@ -1,6 +1,6 @@ package com.shifthackz.aisdv1.data.remote -import com.shifthackz.aisdv1.data.mappers.mapRawToDomain +import com.shifthackz.aisdv1.data.mappers.mapRawToCheckpointDomain import com.shifthackz.aisdv1.domain.datasource.HuggingFaceModelsDataSource import com.shifthackz.aisdv1.domain.entity.HuggingFaceModel import com.shifthackz.aisdv1.network.api.sdai.HuggingFaceModelsApi @@ -12,6 +12,6 @@ internal class HuggingFaceModelsRemoteDataSource( override fun fetchHuggingFaceModels() = api .fetchHuggingFaceModels() - .map(List::mapRawToDomain) + .map(List::mapRawToCheckpointDomain) .onErrorReturn { listOf(HuggingFaceModel.default) } } diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/remote/StabilityAiEnginesRemoteDataSource.kt b/data/src/main/java/com/shifthackz/aisdv1/data/remote/StabilityAiEnginesRemoteDataSource.kt index 7fdcd173..b6d68121 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/remote/StabilityAiEnginesRemoteDataSource.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/remote/StabilityAiEnginesRemoteDataSource.kt @@ -1,6 +1,6 @@ package com.shifthackz.aisdv1.data.remote -import com.shifthackz.aisdv1.data.mappers.mapRawToDomain +import com.shifthackz.aisdv1.data.mappers.mapRawToCheckpointDomain import com.shifthackz.aisdv1.domain.datasource.StabilityAiEnginesDataSource import com.shifthackz.aisdv1.domain.entity.StabilityAiEngine import com.shifthackz.aisdv1.network.api.stabilityai.StabilityAiApi @@ -13,5 +13,5 @@ internal class StabilityAiEnginesRemoteDataSource( override fun fetch(): Single> = api .fetchEngines() - .map(List::mapRawToDomain) + .map(List::mapRawToCheckpointDomain) } diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/remote/StableDiffusionEmbeddingsRemoteDataSource.kt b/data/src/main/java/com/shifthackz/aisdv1/data/remote/StableDiffusionEmbeddingsRemoteDataSource.kt index 0da9eb8b..ae45d962 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/remote/StableDiffusionEmbeddingsRemoteDataSource.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/remote/StableDiffusionEmbeddingsRemoteDataSource.kt @@ -1,8 +1,8 @@ package com.shifthackz.aisdv1.data.remote -import com.shifthackz.aisdv1.data.mappers.mapRawToDomain +import com.shifthackz.aisdv1.data.mappers.mapRawToCheckpointDomain import com.shifthackz.aisdv1.data.provider.ServerUrlProvider -import com.shifthackz.aisdv1.domain.datasource.StableDiffusionEmbeddingsDataSource +import com.shifthackz.aisdv1.domain.datasource.EmbeddingsDataSource import com.shifthackz.aisdv1.network.api.automatic1111.Automatic1111RestApi import com.shifthackz.aisdv1.network.api.automatic1111.Automatic1111RestApi.Companion.PATH_EMBEDDINGS import com.shifthackz.aisdv1.network.response.SdEmbeddingsResponse @@ -10,9 +10,9 @@ import com.shifthackz.aisdv1.network.response.SdEmbeddingsResponse internal class StableDiffusionEmbeddingsRemoteDataSource( private val serverUrlProvider: ServerUrlProvider, private val api: Automatic1111RestApi, -) : StableDiffusionEmbeddingsDataSource.Remote { +) : EmbeddingsDataSource.Remote.Automatic1111 { override fun fetchEmbeddings() = serverUrlProvider(PATH_EMBEDDINGS) .flatMap(api::fetchEmbeddings) - .map(SdEmbeddingsResponse::mapRawToDomain) + .map(SdEmbeddingsResponse::mapRawToCheckpointDomain) } diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/remote/StableDiffusionHyperNetworksRemoteDataSource.kt b/data/src/main/java/com/shifthackz/aisdv1/data/remote/StableDiffusionHyperNetworksRemoteDataSource.kt index c72a3beb..059d8dd7 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/remote/StableDiffusionHyperNetworksRemoteDataSource.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/remote/StableDiffusionHyperNetworksRemoteDataSource.kt @@ -1,6 +1,6 @@ package com.shifthackz.aisdv1.data.remote -import com.shifthackz.aisdv1.data.mappers.mapRawToDomain +import com.shifthackz.aisdv1.data.mappers.mapRawToCheckpointDomain import com.shifthackz.aisdv1.data.provider.ServerUrlProvider import com.shifthackz.aisdv1.domain.datasource.StableDiffusionHyperNetworksDataSource import com.shifthackz.aisdv1.network.api.automatic1111.Automatic1111RestApi @@ -14,5 +14,5 @@ internal class StableDiffusionHyperNetworksRemoteDataSource( override fun fetchHyperNetworks() = serverUrlProvider(PATH_HYPER_NETWORKS) .flatMap(api::fetchHyperNetworks) - .map(List::mapRawToDomain) + .map(List::mapRawToCheckpointDomain) } diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/remote/StableDiffusionLorasRemoteDataSource.kt b/data/src/main/java/com/shifthackz/aisdv1/data/remote/StableDiffusionLorasRemoteDataSource.kt index 7019e6ed..dd0d546d 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/remote/StableDiffusionLorasRemoteDataSource.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/remote/StableDiffusionLorasRemoteDataSource.kt @@ -2,8 +2,8 @@ package com.shifthackz.aisdv1.data.remote import com.shifthackz.aisdv1.data.mappers.mapToDomain import com.shifthackz.aisdv1.data.provider.ServerUrlProvider -import com.shifthackz.aisdv1.domain.datasource.StableDiffusionLorasDataSource -import com.shifthackz.aisdv1.domain.entity.StableDiffusionLora +import com.shifthackz.aisdv1.domain.datasource.LorasDataSource +import com.shifthackz.aisdv1.domain.entity.LoRA import com.shifthackz.aisdv1.network.api.automatic1111.Automatic1111RestApi import com.shifthackz.aisdv1.network.api.automatic1111.Automatic1111RestApi.Companion.PATH_LORAS import com.shifthackz.aisdv1.network.model.StableDiffusionLoraRaw @@ -12,9 +12,9 @@ import io.reactivex.rxjava3.core.Single internal class StableDiffusionLorasRemoteDataSource( private val serverUrlProvider: ServerUrlProvider, private val api: Automatic1111RestApi, -) : StableDiffusionLorasDataSource.Remote { +) : LorasDataSource.Remote.Automatic1111 { - override fun fetchLoras(): Single> = serverUrlProvider(PATH_LORAS) + override fun fetchLoras(): Single> = serverUrlProvider(PATH_LORAS) .flatMap(api::fetchLoras) .map(List::mapToDomain) } diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/remote/StableDiffusionModelsRemoteDataSource.kt b/data/src/main/java/com/shifthackz/aisdv1/data/remote/StableDiffusionModelsRemoteDataSource.kt index e3a36fea..46485857 100755 --- a/data/src/main/java/com/shifthackz/aisdv1/data/remote/StableDiffusionModelsRemoteDataSource.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/remote/StableDiffusionModelsRemoteDataSource.kt @@ -1,6 +1,6 @@ package com.shifthackz.aisdv1.data.remote -import com.shifthackz.aisdv1.data.mappers.mapRawToDomain +import com.shifthackz.aisdv1.data.mappers.mapRawToCheckpointDomain import com.shifthackz.aisdv1.data.provider.ServerUrlProvider import com.shifthackz.aisdv1.domain.datasource.StableDiffusionModelsDataSource import com.shifthackz.aisdv1.network.api.automatic1111.Automatic1111RestApi @@ -14,5 +14,5 @@ internal class StableDiffusionModelsRemoteDataSource( override fun fetchSdModels() = serverUrlProvider(PATH_SD_MODELS) .flatMap(api::fetchSdModels) - .map(List::mapRawToDomain) + .map(List::mapRawToCheckpointDomain) } diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/remote/StableDiffusionSamplersRemoteDataSource.kt b/data/src/main/java/com/shifthackz/aisdv1/data/remote/StableDiffusionSamplersRemoteDataSource.kt index ad1200ec..44056fa7 100755 --- a/data/src/main/java/com/shifthackz/aisdv1/data/remote/StableDiffusionSamplersRemoteDataSource.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/remote/StableDiffusionSamplersRemoteDataSource.kt @@ -1,6 +1,6 @@ package com.shifthackz.aisdv1.data.remote -import com.shifthackz.aisdv1.data.mappers.mapRawToDomain +import com.shifthackz.aisdv1.data.mappers.mapRawToCheckpointDomain import com.shifthackz.aisdv1.data.provider.ServerUrlProvider import com.shifthackz.aisdv1.domain.datasource.StableDiffusionSamplersDataSource import com.shifthackz.aisdv1.network.api.automatic1111.Automatic1111RestApi @@ -14,5 +14,5 @@ internal class StableDiffusionSamplersRemoteDataSource( override fun fetchSamplers() = serverUrlProvider(PATH_SAMPLERS) .flatMap(api::fetchSamplers) - .map(List::mapRawToDomain) + .map(List::mapRawToCheckpointDomain) } \ No newline at end of file diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/remote/SwarmUiEmbeddingsRemoteDataSource.kt b/data/src/main/java/com/shifthackz/aisdv1/data/remote/SwarmUiEmbeddingsRemoteDataSource.kt new file mode 100644 index 00000000..6909ff53 --- /dev/null +++ b/data/src/main/java/com/shifthackz/aisdv1/data/remote/SwarmUiEmbeddingsRemoteDataSource.kt @@ -0,0 +1,29 @@ +package com.shifthackz.aisdv1.data.remote + +import com.shifthackz.aisdv1.data.mappers.mapRawToEmbeddingDomain +import com.shifthackz.aisdv1.data.provider.ServerUrlProvider +import com.shifthackz.aisdv1.domain.datasource.EmbeddingsDataSource +import com.shifthackz.aisdv1.domain.entity.Embedding +import com.shifthackz.aisdv1.network.api.swarmui.SwarmUiApi +import com.shifthackz.aisdv1.network.api.swarmui.SwarmUiApi.Companion.PATH_MODELS +import com.shifthackz.aisdv1.network.request.SwarmUiModelsRequest +import com.shifthackz.aisdv1.network.response.SwarmUiModelsResponse +import io.reactivex.rxjava3.core.Single + +class SwarmUiEmbeddingsRemoteDataSource( + private val serverUrlProvider: ServerUrlProvider, + private val api: SwarmUiApi, +) : EmbeddingsDataSource.Remote.SwarmUi { + + override fun fetchEmbeddings(sessionId: String): Single> = serverUrlProvider(PATH_MODELS) + .flatMap { url -> + val request = SwarmUiModelsRequest( + sessionId = sessionId, + subType = "Embedding", + path = "", + depth = 3, + ) + api.fetchModels(url, request) + } + .map(SwarmUiModelsResponse::mapRawToEmbeddingDomain) +} diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/remote/SwarmUiGenerationRemoteDataSource.kt b/data/src/main/java/com/shifthackz/aisdv1/data/remote/SwarmUiGenerationRemoteDataSource.kt new file mode 100644 index 00000000..958be2e3 --- /dev/null +++ b/data/src/main/java/com/shifthackz/aisdv1/data/remote/SwarmUiGenerationRemoteDataSource.kt @@ -0,0 +1,75 @@ +package com.shifthackz.aisdv1.data.remote + +import com.shifthackz.aisdv1.core.common.extensions.fixUrlSlashes +import com.shifthackz.aisdv1.core.imageprocessing.Base64EncodingConverter +import com.shifthackz.aisdv1.core.imageprocessing.BitmapToBase64Converter +import com.shifthackz.aisdv1.data.mappers.mapCloudToAiGenResult +import com.shifthackz.aisdv1.data.mappers.mapToSwarmUiRequest +import com.shifthackz.aisdv1.data.provider.ServerUrlProvider +import com.shifthackz.aisdv1.domain.datasource.SwarmUiGenerationDataSource +import com.shifthackz.aisdv1.domain.entity.AiGenerationResult +import com.shifthackz.aisdv1.domain.entity.ImageToImagePayload +import com.shifthackz.aisdv1.domain.entity.TextToImagePayload +import com.shifthackz.aisdv1.network.api.swarmui.SwarmUiApi +import com.shifthackz.aisdv1.network.api.swarmui.SwarmUiApi.Companion.PATH_GENERATE +import com.shifthackz.aisdv1.network.request.SwarmUiGenerationRequest +import io.reactivex.rxjava3.core.Single + +class SwarmUiGenerationRemoteDataSource( + private val serverUrlProvider: ServerUrlProvider, + private val api: SwarmUiApi, + private val bmpToBase64Converter: BitmapToBase64Converter, + private val base64EncodingConverter: Base64EncodingConverter, +) : SwarmUiGenerationDataSource.Remote { + + override fun textToImage( + sessionId: String, + model: String, + payload: TextToImagePayload + ): Single = + generate( + payload = payload, + request = payload.mapToSwarmUiRequest(sessionId, model), + ) + .map(Pair::mapCloudToAiGenResult) + + override fun imageToImage( + sessionId: String, + model: String, + payload: ImageToImagePayload, + ): Single = payload + .base64Image + .let(Base64EncodingConverter::Input) + .let(base64EncodingConverter::invoke) + .map(Base64EncodingConverter.Output::base64) + .map { base64 -> "data:image/png;base64,${base64}" } + .map { base64Uri -> payload.copy(base64Image = base64Uri) } + .flatMap { encodedPayload -> + generate( + payload = encodedPayload, + request = encodedPayload.mapToSwarmUiRequest(sessionId, model), + ) + } + .map { (_, outBase64) -> payload to outBase64 } + .map(Pair::mapCloudToAiGenResult) + + private fun generate( + payload: T, + request: SwarmUiGenerationRequest, + ): Single> = serverUrlProvider(PATH_GENERATE) + .flatMap { url -> api.generate(url, request) } + .flatMap { response -> + serverUrlProvider("").map { url -> response to url } + } + .flatMap { (response, url) -> + response.images + ?.firstOrNull() + ?.let { endpoint -> Single.just("$url/$endpoint".fixUrlSlashes()) } + ?: Single.error(IllegalStateException("Bad response")) + } + .flatMap(api::downloadImage) + .map(BitmapToBase64Converter::Input) + .flatMap(bmpToBase64Converter::invoke) + .map(BitmapToBase64Converter.Output::base64ImageString) + .map { base64 -> payload to base64 } +} diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/remote/SwarmUiLorasRemoteDataSource.kt b/data/src/main/java/com/shifthackz/aisdv1/data/remote/SwarmUiLorasRemoteDataSource.kt new file mode 100644 index 00000000..0b6b14df --- /dev/null +++ b/data/src/main/java/com/shifthackz/aisdv1/data/remote/SwarmUiLorasRemoteDataSource.kt @@ -0,0 +1,29 @@ +package com.shifthackz.aisdv1.data.remote + +import com.shifthackz.aisdv1.data.mappers.mapRawToLoraDomain +import com.shifthackz.aisdv1.data.provider.ServerUrlProvider +import com.shifthackz.aisdv1.domain.datasource.LorasDataSource +import com.shifthackz.aisdv1.domain.entity.LoRA +import com.shifthackz.aisdv1.network.api.swarmui.SwarmUiApi +import com.shifthackz.aisdv1.network.api.swarmui.SwarmUiApi.Companion.PATH_MODELS +import com.shifthackz.aisdv1.network.request.SwarmUiModelsRequest +import com.shifthackz.aisdv1.network.response.SwarmUiModelsResponse +import io.reactivex.rxjava3.core.Single + +internal class SwarmUiLorasRemoteDataSource( + private val serverUrlProvider: ServerUrlProvider, + private val api: SwarmUiApi, +) : LorasDataSource.Remote.SwarmUi { + + override fun fetchLoras(sessionId: String): Single> = serverUrlProvider(PATH_MODELS) + .flatMap { url -> + val request = SwarmUiModelsRequest( + sessionId = sessionId, + subType = "LoRA", + path = "", + depth = 3, + ) + api.fetchModels(url, request) + } + .map(SwarmUiModelsResponse::mapRawToLoraDomain) +} diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/remote/SwarmUiModelsRemoteDataSource.kt b/data/src/main/java/com/shifthackz/aisdv1/data/remote/SwarmUiModelsRemoteDataSource.kt new file mode 100644 index 00000000..c2e57c04 --- /dev/null +++ b/data/src/main/java/com/shifthackz/aisdv1/data/remote/SwarmUiModelsRemoteDataSource.kt @@ -0,0 +1,30 @@ +package com.shifthackz.aisdv1.data.remote + +import com.shifthackz.aisdv1.data.mappers.mapRawToCheckpointDomain +import com.shifthackz.aisdv1.data.provider.ServerUrlProvider +import com.shifthackz.aisdv1.domain.datasource.SwarmUiModelsDataSource +import com.shifthackz.aisdv1.domain.entity.SwarmUiModel +import com.shifthackz.aisdv1.network.api.swarmui.SwarmUiApi +import com.shifthackz.aisdv1.network.api.swarmui.SwarmUiApi.Companion.PATH_MODELS +import com.shifthackz.aisdv1.network.request.SwarmUiModelsRequest +import com.shifthackz.aisdv1.network.response.SwarmUiModelsResponse +import io.reactivex.rxjava3.core.Single + +internal class SwarmUiModelsRemoteDataSource( + private val serverUrlProvider: ServerUrlProvider, + private val api: SwarmUiApi, +) : SwarmUiModelsDataSource.Remote { + + override fun fetchSwarmModels(sessionId: String): Single> = PATH_MODELS + .let(serverUrlProvider::invoke) + .flatMap { url -> + val request = SwarmUiModelsRequest( + sessionId = sessionId, + subType = "Stable-Diffusion", + path = "", + depth = 3, + ) + api.fetchModels(url, request) + } + .map(SwarmUiModelsResponse::mapRawToCheckpointDomain) +} diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/remote/SwarmUiSessionDataSourceImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/remote/SwarmUiSessionDataSourceImpl.kt new file mode 100644 index 00000000..8c191c30 --- /dev/null +++ b/data/src/main/java/com/shifthackz/aisdv1/data/remote/SwarmUiSessionDataSourceImpl.kt @@ -0,0 +1,51 @@ +package com.shifthackz.aisdv1.data.remote + +import com.shifthackz.aisdv1.core.common.extensions.fixUrlSlashes +import com.shifthackz.aisdv1.data.provider.ServerUrlProvider +import com.shifthackz.aisdv1.domain.datasource.SwarmUiSessionDataSource +import com.shifthackz.aisdv1.domain.preference.SessionPreference +import com.shifthackz.aisdv1.network.api.swarmui.SwarmUiApi +import com.shifthackz.aisdv1.network.api.swarmui.SwarmUiApi.Companion.PATH_SESSION +import com.shifthackz.aisdv1.network.exception.BadSessionException +import io.reactivex.rxjava3.core.Single + +internal class SwarmUiSessionDataSourceImpl( + private val api: SwarmUiApi, + private val sessionPreference: SessionPreference, + private val serverUrlProvider: ServerUrlProvider, +) : SwarmUiSessionDataSource { + + override fun getSessionId(connectUrl: String?): Single = + if (sessionPreference.swarmUiSessionId.isBlank()) { + forceRenew(connectUrl) + } else { + Single.just(sessionPreference.swarmUiSessionId) + } + + override fun forceRenew(connectUrl: String?): Single { + val chain = connectUrl + ?.let { url -> "$url/$PATH_SESSION".fixUrlSlashes() } + ?.let(api::getNewSession) + ?: serverUrlProvider(PATH_SESSION).flatMap(api::getNewSession) + + return chain + .flatMap { response -> + response.sessionId + ?.takeIf(String::isNotBlank) + ?.let { Single.just(it) } + ?: Single.error(IllegalStateException("Bad session ID.")) + } + .map { sessionId -> + sessionPreference.swarmUiSessionId = sessionId + sessionId + } + } + + override fun handleSessionError(chain: Single): Single = chain.onErrorResumeNext { t -> + if (t is BadSessionException) { + forceRenew().flatMap { chain } + } else { + Single.error(t) + } + } +} diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/repository/EmbeddingsRepositoryImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/repository/EmbeddingsRepositoryImpl.kt new file mode 100644 index 00000000..d2931538 --- /dev/null +++ b/data/src/main/java/com/shifthackz/aisdv1/data/repository/EmbeddingsRepositoryImpl.kt @@ -0,0 +1,39 @@ +package com.shifthackz.aisdv1.data.repository + +import com.shifthackz.aisdv1.domain.datasource.EmbeddingsDataSource +import com.shifthackz.aisdv1.domain.datasource.SwarmUiSessionDataSource +import com.shifthackz.aisdv1.domain.entity.Embedding +import com.shifthackz.aisdv1.domain.entity.ServerSource +import com.shifthackz.aisdv1.domain.preference.PreferenceManager +import com.shifthackz.aisdv1.domain.repository.EmbeddingsRepository +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single + +internal class EmbeddingsRepositoryImpl( + private val rdsA1111: EmbeddingsDataSource.Remote.Automatic1111, + private val rdsSwarm: EmbeddingsDataSource.Remote.SwarmUi, + private val swarmSession: SwarmUiSessionDataSource, + private val lds: EmbeddingsDataSource.Local, + private val preferenceManager: PreferenceManager, +) : EmbeddingsRepository { + + override fun fetchEmbeddings(): Completable = when (preferenceManager.source) { + ServerSource.AUTOMATIC1111 -> rdsA1111 + .fetchEmbeddings() + .flatMapCompletable(lds::insertEmbeddings) + + ServerSource.SWARM_UI -> swarmSession + .getSessionId() + .flatMap(rdsSwarm::fetchEmbeddings) + .let(swarmSession::handleSessionError) + .flatMapCompletable(lds::insertEmbeddings) + + else -> Completable.complete() + } + + override fun fetchAndGetEmbeddings(): Single> = fetchEmbeddings() + .onErrorComplete() + .andThen(lds.getEmbeddings()) + + override fun getEmbeddings(): Single> = lds.getEmbeddings() +} diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/repository/LorasRepositoryImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/repository/LorasRepositoryImpl.kt new file mode 100644 index 00000000..745657e2 --- /dev/null +++ b/data/src/main/java/com/shifthackz/aisdv1/data/repository/LorasRepositoryImpl.kt @@ -0,0 +1,39 @@ +package com.shifthackz.aisdv1.data.repository + +import com.shifthackz.aisdv1.domain.datasource.LorasDataSource +import com.shifthackz.aisdv1.domain.datasource.SwarmUiSessionDataSource +import com.shifthackz.aisdv1.domain.entity.LoRA +import com.shifthackz.aisdv1.domain.entity.ServerSource +import com.shifthackz.aisdv1.domain.preference.PreferenceManager +import com.shifthackz.aisdv1.domain.repository.LorasRepository +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single + +internal class LorasRepositoryImpl( + private val rdsA1111: LorasDataSource.Remote.Automatic1111, + private val rdsSwarm: LorasDataSource.Remote.SwarmUi, + private val swarmSession: SwarmUiSessionDataSource, + private val lds: LorasDataSource.Local, + private val preferenceManager: PreferenceManager, +) : LorasRepository { + + override fun fetchLoras(): Completable = when (preferenceManager.source) { + ServerSource.AUTOMATIC1111 -> rdsA1111 + .fetchLoras() + .flatMapCompletable(lds::insertLoras) + + ServerSource.SWARM_UI -> swarmSession + .getSessionId() + .flatMap(rdsSwarm::fetchLoras) + .let(swarmSession::handleSessionError) + .flatMapCompletable(lds::insertLoras) + + else -> Completable.complete() + } + + override fun fetchAndGetLoras(): Single> = fetchLoras() + .onErrorComplete() + .andThen(getLoras()) + + override fun getLoras(): Single> = lds.getLoras() +} diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/repository/StableDiffusionEmbeddingsRepositoryImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/repository/StableDiffusionEmbeddingsRepositoryImpl.kt deleted file mode 100644 index 1f0dc384..00000000 --- a/data/src/main/java/com/shifthackz/aisdv1/data/repository/StableDiffusionEmbeddingsRepositoryImpl.kt +++ /dev/null @@ -1,20 +0,0 @@ -package com.shifthackz.aisdv1.data.repository - -import com.shifthackz.aisdv1.domain.datasource.StableDiffusionEmbeddingsDataSource -import com.shifthackz.aisdv1.domain.repository.StableDiffusionEmbeddingsRepository - -internal class StableDiffusionEmbeddingsRepositoryImpl( - private val remoteDataSource: StableDiffusionEmbeddingsDataSource.Remote, - private val localDataSource: StableDiffusionEmbeddingsDataSource.Local, -) : StableDiffusionEmbeddingsRepository { - - override fun fetchEmbeddings() = remoteDataSource - .fetchEmbeddings() - .flatMapCompletable(localDataSource::insertEmbeddings) - - override fun fetchAndGetEmbeddings() = fetchEmbeddings() - .onErrorComplete() - .andThen(localDataSource.getEmbeddings()) - - override fun getEmbeddings() = localDataSource.getEmbeddings() -} diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/repository/StableDiffusionLorasRepositoryImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/repository/StableDiffusionLorasRepositoryImpl.kt deleted file mode 100644 index 02e4d841..00000000 --- a/data/src/main/java/com/shifthackz/aisdv1/data/repository/StableDiffusionLorasRepositoryImpl.kt +++ /dev/null @@ -1,20 +0,0 @@ -package com.shifthackz.aisdv1.data.repository - -import com.shifthackz.aisdv1.domain.datasource.StableDiffusionLorasDataSource -import com.shifthackz.aisdv1.domain.repository.StableDiffusionLorasRepository - -internal class StableDiffusionLorasRepositoryImpl( - private val remoteDataSource: StableDiffusionLorasDataSource.Remote, - private val localDataSource: StableDiffusionLorasDataSource.Local, -) : StableDiffusionLorasRepository { - - override fun fetchLoras() = remoteDataSource - .fetchLoras() - .flatMapCompletable(localDataSource::insertLoras) - - override fun fetchAndGetLoras() = fetchLoras() - .onErrorComplete() - .andThen(getLoras()) - - override fun getLoras() = localDataSource.getLoras() -} diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/repository/SwarmUiGenerationRepositoryImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/repository/SwarmUiGenerationRepositoryImpl.kt new file mode 100644 index 00000000..eccf7001 --- /dev/null +++ b/data/src/main/java/com/shifthackz/aisdv1/data/repository/SwarmUiGenerationRepositoryImpl.kt @@ -0,0 +1,62 @@ +package com.shifthackz.aisdv1.data.repository + +import com.shifthackz.aisdv1.core.imageprocessing.Base64ToBitmapConverter +import com.shifthackz.aisdv1.data.core.CoreGenerationRepository +import com.shifthackz.aisdv1.domain.datasource.GenerationResultDataSource +import com.shifthackz.aisdv1.domain.datasource.SwarmUiGenerationDataSource +import com.shifthackz.aisdv1.domain.datasource.SwarmUiSessionDataSource +import com.shifthackz.aisdv1.domain.entity.AiGenerationResult +import com.shifthackz.aisdv1.domain.entity.ImageToImagePayload +import com.shifthackz.aisdv1.domain.entity.TextToImagePayload +import com.shifthackz.aisdv1.domain.gateway.MediaStoreGateway +import com.shifthackz.aisdv1.domain.preference.PreferenceManager +import com.shifthackz.aisdv1.domain.repository.SwarmUiGenerationRepository +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single + +internal class SwarmUiGenerationRepositoryImpl( + mediaStoreGateway: MediaStoreGateway, + base64ToBitmapConverter: Base64ToBitmapConverter, + localDataSource: GenerationResultDataSource.Local, + private val preferenceManager: PreferenceManager, + private val session: SwarmUiSessionDataSource, + private val remoteDataSource: SwarmUiGenerationDataSource.Remote, +) : CoreGenerationRepository( + mediaStoreGateway, + base64ToBitmapConverter, + localDataSource, + preferenceManager, +), SwarmUiGenerationRepository { + + override fun checkApiAvailability(): Completable = session + .getSessionId() + .ignoreElement() + + override fun checkApiAvailability(url: String): Completable = session + .getSessionId(url) + .ignoreElement() + + override fun generateFromText(payload: TextToImagePayload): Single = session + .getSessionId() + .flatMap { sessionId -> + remoteDataSource.textToImage( + sessionId = sessionId, + model = preferenceManager.swarmUiModel, + payload = payload, + ) + } + .let(session::handleSessionError) + .flatMap(::insertGenerationResult) + + override fun generateFromImage(payload: ImageToImagePayload): Single = session + .getSessionId() + .flatMap { sessionId -> + remoteDataSource.imageToImage( + sessionId = sessionId, + model = preferenceManager.swarmUiModel, + payload = payload, + ) + } + .let(session::handleSessionError) + .flatMap(::insertGenerationResult) +} diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/repository/SwarmUiModelsRepositoryImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/repository/SwarmUiModelsRepositoryImpl.kt new file mode 100644 index 00000000..e2d6ca4c --- /dev/null +++ b/data/src/main/java/com/shifthackz/aisdv1/data/repository/SwarmUiModelsRepositoryImpl.kt @@ -0,0 +1,27 @@ +package com.shifthackz.aisdv1.data.repository + +import com.shifthackz.aisdv1.domain.datasource.SwarmUiModelsDataSource +import com.shifthackz.aisdv1.domain.datasource.SwarmUiSessionDataSource +import com.shifthackz.aisdv1.domain.entity.SwarmUiModel +import com.shifthackz.aisdv1.domain.repository.SwarmUiModelsRepository +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single + +internal class SwarmUiModelsRepositoryImpl( + private val session: SwarmUiSessionDataSource, + private val rds: SwarmUiModelsDataSource.Remote, + private val lds: SwarmUiModelsDataSource.Local, +) : SwarmUiModelsRepository { + + override fun fetchModels(): Completable = session + .getSessionId() + .flatMap(rds::fetchSwarmModels) + .let(session::handleSessionError) + .flatMapCompletable(lds::insertModels) + + override fun fetchAndGetModels(): Single> = fetchModels() + .onErrorComplete() + .andThen(getModels()) + + override fun getModels(): Single> = lds.getModels() +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/local/StableDiffusionEmbeddingsLocalDataSourceTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/local/EmbeddingsLocalDataSourceTest.kt similarity index 87% rename from data/src/test/java/com/shifthackz/aisdv1/data/local/StableDiffusionEmbeddingsLocalDataSourceTest.kt rename to data/src/test/java/com/shifthackz/aisdv1/data/local/EmbeddingsLocalDataSourceTest.kt index 09cc3b27..f5a8c652 100644 --- a/data/src/test/java/com/shifthackz/aisdv1/data/local/StableDiffusionEmbeddingsLocalDataSourceTest.kt +++ b/data/src/test/java/com/shifthackz/aisdv1/data/local/EmbeddingsLocalDataSourceTest.kt @@ -1,7 +1,7 @@ package com.shifthackz.aisdv1.data.local +import com.shifthackz.aisdv1.data.mocks.mockEmbeddings import com.shifthackz.aisdv1.data.mocks.mockStableDiffusionEmbeddingEntities -import com.shifthackz.aisdv1.data.mocks.mockStableDiffusionEmbeddings import com.shifthackz.aisdv1.storage.db.cache.dao.StableDiffusionEmbeddingDao import io.mockk.every import io.mockk.mockk @@ -9,12 +9,12 @@ import io.reactivex.rxjava3.core.Completable import io.reactivex.rxjava3.core.Single import org.junit.Test -class StableDiffusionEmbeddingsLocalDataSourceTest { +class EmbeddingsLocalDataSourceTest { private val stubException = Throwable("Database error.") private val stubDao = mockk() - private val localDataSource = StableDiffusionEmbeddingsLocalDataSource(stubDao) + private val localDataSource = EmbeddingsLocalDataSource(stubDao) @Test fun `given attempt to get embeddings, dao returns list, expected valid domain model list value`() { @@ -26,7 +26,7 @@ class StableDiffusionEmbeddingsLocalDataSourceTest { .getEmbeddings() .test() .assertNoErrors() - .assertValue(mockStableDiffusionEmbeddings) + .assertValue(mockEmbeddings) .await() .assertComplete() } @@ -72,7 +72,7 @@ class StableDiffusionEmbeddingsLocalDataSourceTest { } returns Completable.complete() localDataSource - .insertEmbeddings(mockStableDiffusionEmbeddings) + .insertEmbeddings(mockEmbeddings) .test() .assertNoErrors() .await() @@ -90,7 +90,7 @@ class StableDiffusionEmbeddingsLocalDataSourceTest { } returns Completable.complete() localDataSource - .insertEmbeddings(mockStableDiffusionEmbeddings) + .insertEmbeddings(mockEmbeddings) .test() .assertError(stubException) .await() @@ -108,7 +108,7 @@ class StableDiffusionEmbeddingsLocalDataSourceTest { } returns Completable.error(stubException) localDataSource - .insertEmbeddings(mockStableDiffusionEmbeddings) + .insertEmbeddings(mockEmbeddings) .test() .assertError(stubException) .await() diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/local/StableDiffusionLorasLocalDataSourceTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/local/LorasLocalDataSourceTest.kt similarity index 96% rename from data/src/test/java/com/shifthackz/aisdv1/data/local/StableDiffusionLorasLocalDataSourceTest.kt rename to data/src/test/java/com/shifthackz/aisdv1/data/local/LorasLocalDataSourceTest.kt index 6a7baaa3..20c265eb 100644 --- a/data/src/test/java/com/shifthackz/aisdv1/data/local/StableDiffusionLorasLocalDataSourceTest.kt +++ b/data/src/test/java/com/shifthackz/aisdv1/data/local/LorasLocalDataSourceTest.kt @@ -9,12 +9,12 @@ import io.reactivex.rxjava3.core.Completable import io.reactivex.rxjava3.core.Single import org.junit.Test -class StableDiffusionLorasLocalDataSourceTest { +class LorasLocalDataSourceTest { private val stubException = Throwable("Database error.") private val stubDao = mockk() - private val localDataSource = StableDiffusionLorasLocalDataSource(stubDao) + private val localDataSource = LorasLocalDataSource(stubDao) @Test fun `given attempt to get loras, dao returns list, expected valid domain model list value`() { diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/local/SwarmUiModelsLocalDataSourceTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/local/SwarmUiModelsLocalDataSourceTest.kt new file mode 100644 index 00000000..34ebbe56 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/local/SwarmUiModelsLocalDataSourceTest.kt @@ -0,0 +1,117 @@ +package com.shifthackz.aisdv1.data.local + +import com.shifthackz.aisdv1.data.mocks.mockSwarmUiModelEntities +import com.shifthackz.aisdv1.data.mocks.mockSwarmUiModels +import com.shifthackz.aisdv1.storage.db.cache.dao.SwarmUiModelDao +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single +import org.junit.Test + +class SwarmUiModelsLocalDataSourceTest { + + private val stubException = Throwable("Database error.") + private val stubDao = mockk() + + private val localDataSource = SwarmUiModelsLocalDataSource(stubDao) + + @Test + fun `given attempt to get models, dao returns list, expected valid domain model list value`() { + every { + stubDao.queryAll() + } returns Single.just(mockSwarmUiModelEntities) + + localDataSource + .getModels() + .test() + .assertNoErrors() + .assertValue { it.size == mockSwarmUiModels.size } + .await() + .assertComplete() + } + + @Test + fun `given attempt to get models, dao returns empty list, expected empty domain model list value`() { + every { + stubDao.queryAll() + } returns Single.just(emptyList()) + + localDataSource + .getModels() + .test() + .assertNoErrors() + .assertValue(emptyList()) + .await() + .assertComplete() + } + + @Test + fun `given attempt to get models, dao throws exception, expected error value`() { + every { + stubDao.queryAll() + } returns Single.error(stubException) + + localDataSource + .getModels() + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to insert models, dao replaces list, expected complete value`() { + every { + stubDao.deleteAll() + } returns Completable.complete() + + every { + stubDao.insertList(any()) + } returns Completable.complete() + + localDataSource + .insertModels(mockSwarmUiModels) + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `given attempt to insert models, dao throws exception during delete, expected error value`() { + every { + stubDao.deleteAll() + } returns Completable.error(stubException) + + every { + stubDao.insertList(any()) + } returns Completable.complete() + + localDataSource + .insertModels(mockSwarmUiModels) + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to insert models, dao throws exception during insertion, expected error value`() { + every { + stubDao.deleteAll() + } returns Completable.complete() + + every { + stubDao.insertList(any()) + } returns Completable.error(stubException) + + localDataSource + .insertModels(mockSwarmUiModels) + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StableDiffusionEmbeddingMocks.kt b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StableDiffusionEmbeddingMocks.kt index a7286a4e..c4a5e9d2 100644 --- a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StableDiffusionEmbeddingMocks.kt +++ b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StableDiffusionEmbeddingMocks.kt @@ -1,8 +1,8 @@ package com.shifthackz.aisdv1.data.mocks -import com.shifthackz.aisdv1.domain.entity.StableDiffusionEmbedding +import com.shifthackz.aisdv1.domain.entity.Embedding -val mockStableDiffusionEmbeddings = listOf( - StableDiffusionEmbedding("keyword_5598"), - StableDiffusionEmbedding("keyword_151297"), +val mockEmbeddings = listOf( + Embedding("keyword_5598"), + Embedding("keyword_151297"), ) diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StableDiffusionLoraMocks.kt b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StableDiffusionLoraMocks.kt index 07b40191..a6531aa2 100644 --- a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StableDiffusionLoraMocks.kt +++ b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StableDiffusionLoraMocks.kt @@ -1,14 +1,14 @@ package com.shifthackz.aisdv1.data.mocks -import com.shifthackz.aisdv1.domain.entity.StableDiffusionLora +import com.shifthackz.aisdv1.domain.entity.LoRA val mockStableDiffusionLoras = listOf( - StableDiffusionLora( + LoRA( name = "name_5598", alias = "alias_5598", path = "/unknown", ), - StableDiffusionLora( + LoRA( name = "name_151297", alias = "alias_151297", path = "/unknown", diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/SwarmUiGenerationResponseMocks.kt b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/SwarmUiGenerationResponseMocks.kt new file mode 100644 index 00000000..e92ab561 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/SwarmUiGenerationResponseMocks.kt @@ -0,0 +1,7 @@ +package com.shifthackz.aisdv1.data.mocks + +import com.shifthackz.aisdv1.network.response.SwarmUiGenerationResponse + +val mockSwarmUiGenerationResponse = SwarmUiGenerationResponse( + images = listOf("/tmp/img.jpg"), +) diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/SwarmUiModelEntityMocks.kt b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/SwarmUiModelEntityMocks.kt new file mode 100644 index 00000000..1f151317 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/SwarmUiModelEntityMocks.kt @@ -0,0 +1,12 @@ +package com.shifthackz.aisdv1.data.mocks + +import com.shifthackz.aisdv1.storage.db.cache.entity.SwarmUiModelEntity + +val mockSwarmUiModelEntities = listOf( + SwarmUiModelEntity( + id = "5598", + name = "5598", + title = "5598", + author = "", + ), +) diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/SwarmUiModelMocks.kt b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/SwarmUiModelMocks.kt new file mode 100644 index 00000000..2917e3bd --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/SwarmUiModelMocks.kt @@ -0,0 +1,11 @@ +package com.shifthackz.aisdv1.data.mocks + +import com.shifthackz.aisdv1.domain.entity.SwarmUiModel + +val mockSwarmUiModels = listOf( + SwarmUiModel( + name = "5598", + title = "5598", + author = "5598", + ), +) diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/SwarmUiModelRawMocks.kt b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/SwarmUiModelRawMocks.kt new file mode 100644 index 00000000..19c9218a --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/SwarmUiModelRawMocks.kt @@ -0,0 +1,11 @@ +package com.shifthackz.aisdv1.data.mocks + +import com.shifthackz.aisdv1.network.model.SwarmUiModelRaw + +val mockSwarmUiModelsRaw = listOf( + SwarmUiModelRaw( + name = "5598", + title = "5598", + author = "5598", + ), +) diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImplTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImplTest.kt index 0da16f96..0535d7f8 100644 --- a/data/src/test/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImplTest.kt +++ b/data/src/test/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImplTest.kt @@ -71,14 +71,14 @@ class PreferenceManagerImplTest { whenever(stubPreference.getString(eq(KEY_SERVER_URL), any())) .thenReturn("") - Assert.assertEquals("", preferenceManager.serverUrl) + Assert.assertEquals("", preferenceManager.automatic1111ServerUrl) whenever(stubPreference.getString(eq(KEY_SERVER_URL), any())) .thenReturn("https://192.168.0.1:7860") - preferenceManager.serverUrl = "https://192.168.0.1:7860" + preferenceManager.automatic1111ServerUrl = "https://192.168.0.1:7860" - Assert.assertEquals("https://192.168.0.1:7860", preferenceManager.serverUrl) + Assert.assertEquals("https://192.168.0.1:7860", preferenceManager.automatic1111ServerUrl) preferenceManager .observe() diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/preference/SessionPreferenceImplTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/preference/SessionPreferenceImplTest.kt index 7d6c9df4..2e645a64 100644 --- a/data/src/test/java/com/shifthackz/aisdv1/data/preference/SessionPreferenceImplTest.kt +++ b/data/src/test/java/com/shifthackz/aisdv1/data/preference/SessionPreferenceImplTest.kt @@ -8,14 +8,14 @@ class SessionPreferenceImplTest { private val sessionPreference = SessionPreferenceImpl() @Test - fun `given user reads default coinsPerDay value, expected -1`() { - Assert.assertEquals(-1, sessionPreference.coinsPerDay) + fun `given user reads default swarmUiSessionId value, expected empty String`() { + Assert.assertEquals("", sessionPreference.swarmUiSessionId) } @Test - fun `given user reads default coinsPerDay value, then changes it, expected -1, then changed value`() { - Assert.assertEquals(-1, sessionPreference.coinsPerDay) - sessionPreference.coinsPerDay = 5598 - Assert.assertEquals(5598, sessionPreference.coinsPerDay) + fun `given user reads default coinsPerDay value, then changes it, expected empty String, then changed value`() { + Assert.assertEquals("", sessionPreference.swarmUiSessionId) + sessionPreference.swarmUiSessionId = "5598" + Assert.assertEquals("5598", sessionPreference.swarmUiSessionId) } } diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/remote/DownloadableModelRemoteDataSourceTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/remote/DownloadableModelRemoteDataSourceTest.kt index c8e7339d..c97b0a9e 100644 --- a/data/src/test/java/com/shifthackz/aisdv1/data/remote/DownloadableModelRemoteDataSourceTest.kt +++ b/data/src/test/java/com/shifthackz/aisdv1/data/remote/DownloadableModelRemoteDataSourceTest.kt @@ -3,7 +3,7 @@ package com.shifthackz.aisdv1.data.remote import com.nhaarman.mockitokotlin2.mock import com.nhaarman.mockitokotlin2.whenever import com.shifthackz.aisdv1.core.common.file.FileProviderDescriptor -import com.shifthackz.aisdv1.data.mappers.mapRawToDomain +import com.shifthackz.aisdv1.data.mappers.mapRawToCheckpointDomain import com.shifthackz.aisdv1.data.mocks.mockDownloadableModelsResponse import com.shifthackz.aisdv1.network.api.sdai.DownloadableModelsApi import io.reactivex.rxjava3.core.Single @@ -25,7 +25,7 @@ class DownloadableModelRemoteDataSourceTest { whenever(stubApi.fetchDownloadableModels()) .thenReturn(Single.just(mockDownloadableModelsResponse)) - val expected = mockDownloadableModelsResponse.mapRawToDomain() + val expected = mockDownloadableModelsResponse.mapRawToCheckpointDomain() remoteDataSource .fetch() diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/remote/StableDiffusionEmbeddingsRemoteDataSourceTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/remote/StableDiffusionEmbeddingsRemoteDataSourceTest.kt index bc31e80f..94528711 100644 --- a/data/src/test/java/com/shifthackz/aisdv1/data/remote/StableDiffusionEmbeddingsRemoteDataSourceTest.kt +++ b/data/src/test/java/com/shifthackz/aisdv1/data/remote/StableDiffusionEmbeddingsRemoteDataSourceTest.kt @@ -3,7 +3,7 @@ package com.shifthackz.aisdv1.data.remote import com.shifthackz.aisdv1.data.mocks.mockEmptySdEmbeddingsResponse import com.shifthackz.aisdv1.data.mocks.mockSdEmbeddingsResponse import com.shifthackz.aisdv1.data.provider.ServerUrlProvider -import com.shifthackz.aisdv1.domain.entity.StableDiffusionEmbedding +import com.shifthackz.aisdv1.domain.entity.Embedding import com.shifthackz.aisdv1.network.api.automatic1111.Automatic1111RestApi import io.mockk.every import io.mockk.mockk @@ -39,7 +39,7 @@ class StableDiffusionEmbeddingsRemoteDataSourceTest { .fetchEmbeddings() .test() .assertNoErrors() - .assertValue(listOf(StableDiffusionEmbedding("1504"))) + .assertValue(listOf(Embedding("1504"))) .await() .assertComplete() } diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/remote/StableDiffusionLorasRemoteDataSourceTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/remote/StableDiffusionLorasRemoteDataSourceTest.kt index 3c999085..a3e129ef 100644 --- a/data/src/test/java/com/shifthackz/aisdv1/data/remote/StableDiffusionLorasRemoteDataSourceTest.kt +++ b/data/src/test/java/com/shifthackz/aisdv1/data/remote/StableDiffusionLorasRemoteDataSourceTest.kt @@ -2,7 +2,7 @@ package com.shifthackz.aisdv1.data.remote import com.shifthackz.aisdv1.data.mocks.mockStableDiffusionLoraRaw import com.shifthackz.aisdv1.data.provider.ServerUrlProvider -import com.shifthackz.aisdv1.domain.entity.StableDiffusionLora +import com.shifthackz.aisdv1.domain.entity.LoRA import com.shifthackz.aisdv1.network.api.automatic1111.Automatic1111RestApi import io.mockk.every import io.mockk.mockk @@ -39,7 +39,7 @@ class StableDiffusionLorasRemoteDataSourceTest { .test() .assertNoErrors() .assertValue { loras -> - loras is List + loras is List && loras.size == mockStableDiffusionLoraRaw.size } .await() diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/remote/SwarmUiEmbeddingsRemoteDataSourceTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/remote/SwarmUiEmbeddingsRemoteDataSourceTest.kt new file mode 100644 index 00000000..d0fbe69d --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/remote/SwarmUiEmbeddingsRemoteDataSourceTest.kt @@ -0,0 +1,79 @@ +package com.shifthackz.aisdv1.data.remote + +import com.shifthackz.aisdv1.data.mocks.mockSwarmUiModelsRaw +import com.shifthackz.aisdv1.data.provider.ServerUrlProvider +import com.shifthackz.aisdv1.domain.entity.Embedding +import com.shifthackz.aisdv1.network.api.swarmui.SwarmUiApi +import com.shifthackz.aisdv1.network.response.SwarmUiModelsResponse +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Single +import org.junit.Before +import org.junit.Test + +class SwarmUiEmbeddingsRemoteDataSourceTest { + + private val stubException = Throwable("Something went wrong.") + private val stubServerUrlProvider = mockk() + private val stubApi = mockk() + + private val remoteDataSource = SwarmUiEmbeddingsRemoteDataSource( + serverUrlProvider = stubServerUrlProvider, + api = stubApi, + ) + + @Before + fun initialize() { + every { + stubServerUrlProvider(any()) + } returns Single.just("http://192.168.0.1:7801") + } + + @Test + fun `given attempt to fetch models, api returns success response, expected valid models list value`() { + every { + stubApi.fetchModels(any(), any()) + } returns Single.just(SwarmUiModelsResponse(mockSwarmUiModelsRaw)) + + remoteDataSource + .fetchEmbeddings("5598") + .test() + .assertNoErrors() + .assertValue { models -> + models is List + && models.size == mockSwarmUiModelsRaw.size + } + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch models, api returns empty response, expected empty models value`() { + every { + stubApi.fetchModels(any(), any()) + } returns Single.just(SwarmUiModelsResponse(emptyList())) + + remoteDataSource + .fetchEmbeddings("5598") + .test() + .assertNoErrors() + .assertValue(emptyList()) + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch models, api returns error response, expected error value`() { + every { + stubApi.fetchModels(any(), any()) + } returns Single.error(stubException) + + remoteDataSource + .fetchEmbeddings("5598") + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/remote/SwarmUiGenerationRemoteDataSourceTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/remote/SwarmUiGenerationRemoteDataSourceTest.kt new file mode 100644 index 00000000..29c431f4 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/remote/SwarmUiGenerationRemoteDataSourceTest.kt @@ -0,0 +1,130 @@ +package com.shifthackz.aisdv1.data.remote + +import android.graphics.Bitmap +import com.shifthackz.aisdv1.core.imageprocessing.Base64EncodingConverter +import com.shifthackz.aisdv1.core.imageprocessing.BitmapToBase64Converter +import com.shifthackz.aisdv1.data.mocks.mockImageToImagePayload +import com.shifthackz.aisdv1.data.mocks.mockSwarmUiGenerationResponse +import com.shifthackz.aisdv1.data.mocks.mockTextToImagePayload +import com.shifthackz.aisdv1.data.provider.ServerUrlProvider +import com.shifthackz.aisdv1.domain.entity.AiGenerationResult +import com.shifthackz.aisdv1.network.api.swarmui.SwarmUiApi +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Single +import org.junit.Before +import org.junit.Test + +class SwarmUiGenerationRemoteDataSourceTest { + + private val stubBitmap = mockk() + private val stubException = Throwable("Something went wrong.") + private val stubServerUrlProvider = mockk() + private val stubApi = mockk() + private val stubBmpToBase64Converter = mockk() + private val stubBase64EncConverter = mockk() + + private val remoteDataSource = SwarmUiGenerationRemoteDataSource( + serverUrlProvider = stubServerUrlProvider, + api = stubApi, + bmpToBase64Converter = stubBmpToBase64Converter, + base64EncodingConverter = stubBase64EncConverter, + ) + + @Before + fun initialize() { + every { + stubServerUrlProvider(any()) + } returns Single.just("http://192.168.0.1:7801") + } + + + @Test + fun `given attempt to generate txt2img, api returns result, expected valid ai generation result value`() { + every { + stubApi.generate(any(), any()) + } returns Single.just(mockSwarmUiGenerationResponse) + + every { + stubApi.downloadImage(any()) + } returns Single.just(stubBitmap) + + every { + stubBmpToBase64Converter(any()) + } returns Single.just(BitmapToBase64Converter.Output("base64")) + + remoteDataSource + .textToImage(SESSION_ID, MODEL, mockTextToImagePayload) + .test() + .assertNoErrors() + .assertValue { it is AiGenerationResult } + .await() + .assertComplete() + } + + @Test + fun `given attempt to generate txt2img, api returns error, expected error value`() { + every { + stubApi.generate(any(), any()) + } returns Single.error(stubException) + + remoteDataSource + .textToImage(SESSION_ID, MODEL, mockTextToImagePayload) + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to generate img2img, api returns result, expected valid ai generation result value`() { + every { + stubBase64EncConverter(any()) + } returns Single.just(Base64EncodingConverter.Output("base64")) + + every { + stubApi.generate(any(), any()) + } returns Single.just(mockSwarmUiGenerationResponse) + + every { + stubApi.downloadImage(any()) + } returns Single.just(stubBitmap) + + every { + stubBmpToBase64Converter(any()) + } returns Single.just(BitmapToBase64Converter.Output("base64")) + + remoteDataSource + .imageToImage(SESSION_ID, MODEL, mockImageToImagePayload) + .test() + .assertNoErrors() + .assertValue { it is AiGenerationResult } + .await() + .assertComplete() + } + + @Test + fun `given attempt to generate img2img, api returns error, expected error value`() { + every { + stubBase64EncConverter(any()) + } returns Single.error(stubException) + + every { + stubApi.generate(any(), any()) + } returns Single.error(stubException) + + remoteDataSource + .imageToImage(SESSION_ID, MODEL, mockImageToImagePayload) + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + companion object { + private const val SESSION_ID = "5598" + private const val MODEL = "OpenStableDiffusion" + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/remote/SwarmUiLorasRemoteDataSourceTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/remote/SwarmUiLorasRemoteDataSourceTest.kt new file mode 100644 index 00000000..ed70d23e --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/remote/SwarmUiLorasRemoteDataSourceTest.kt @@ -0,0 +1,79 @@ +package com.shifthackz.aisdv1.data.remote + +import com.shifthackz.aisdv1.data.mocks.mockSwarmUiModelsRaw +import com.shifthackz.aisdv1.data.provider.ServerUrlProvider +import com.shifthackz.aisdv1.domain.entity.LoRA +import com.shifthackz.aisdv1.network.api.swarmui.SwarmUiApi +import com.shifthackz.aisdv1.network.response.SwarmUiModelsResponse +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Single +import org.junit.Before +import org.junit.Test + +class SwarmUiLorasRemoteDataSourceTest { + + private val stubException = Throwable("Something went wrong.") + private val stubServerUrlProvider = mockk() + private val stubApi = mockk() + + private val remoteDataSource = SwarmUiLorasRemoteDataSource( + serverUrlProvider = stubServerUrlProvider, + api = stubApi, + ) + + @Before + fun initialize() { + every { + stubServerUrlProvider(any()) + } returns Single.just("http://192.168.0.1:7801") + } + + @Test + fun `given attempt to fetch models, api returns success response, expected valid models list value`() { + every { + stubApi.fetchModels(any(), any()) + } returns Single.just(SwarmUiModelsResponse(mockSwarmUiModelsRaw)) + + remoteDataSource + .fetchLoras("5598") + .test() + .assertNoErrors() + .assertValue { models -> + models is List + && models.size == mockSwarmUiModelsRaw.size + } + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch models, api returns empty response, expected empty models value`() { + every { + stubApi.fetchModels(any(), any()) + } returns Single.just(SwarmUiModelsResponse(emptyList())) + + remoteDataSource + .fetchLoras("5598") + .test() + .assertNoErrors() + .assertValue(emptyList()) + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch models, api returns error response, expected error value`() { + every { + stubApi.fetchModels(any(), any()) + } returns Single.error(stubException) + + remoteDataSource + .fetchLoras("5598") + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/remote/SwarmUiModelsRemoteDataSourceTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/remote/SwarmUiModelsRemoteDataSourceTest.kt new file mode 100644 index 00000000..e4a68d0d --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/remote/SwarmUiModelsRemoteDataSourceTest.kt @@ -0,0 +1,79 @@ +package com.shifthackz.aisdv1.data.remote + +import com.shifthackz.aisdv1.data.mocks.mockSwarmUiModelsRaw +import com.shifthackz.aisdv1.data.provider.ServerUrlProvider +import com.shifthackz.aisdv1.domain.entity.SwarmUiModel +import com.shifthackz.aisdv1.network.api.swarmui.SwarmUiApi +import com.shifthackz.aisdv1.network.response.SwarmUiModelsResponse +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Single +import org.junit.Before +import org.junit.Test + +class SwarmUiModelsRemoteDataSourceTest { + + private val stubException = Throwable("Something went wrong.") + private val stubServerUrlProvider = mockk() + private val stubApi = mockk() + + private val remoteDataSource = SwarmUiModelsRemoteDataSource( + serverUrlProvider = stubServerUrlProvider, + api = stubApi, + ) + + @Before + fun initialize() { + every { + stubServerUrlProvider(any()) + } returns Single.just("http://192.168.0.1:7801") + } + + @Test + fun `given attempt to fetch models, api returns success response, expected valid models list value`() { + every { + stubApi.fetchModels(any(), any()) + } returns Single.just(SwarmUiModelsResponse(mockSwarmUiModelsRaw)) + + remoteDataSource + .fetchSwarmModels("5598") + .test() + .assertNoErrors() + .assertValue { models -> + models is List + && models.size == mockSwarmUiModelsRaw.size + } + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch models, api returns empty response, expected empty models value`() { + every { + stubApi.fetchModels(any(), any()) + } returns Single.just(SwarmUiModelsResponse(emptyList())) + + remoteDataSource + .fetchSwarmModels("5598") + .test() + .assertNoErrors() + .assertValue(emptyList()) + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch models, api returns error response, expected error value`() { + every { + stubApi.fetchModels(any(), any()) + } returns Single.error(stubException) + + remoteDataSource + .fetchSwarmModels("5598") + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/remote/SwarmUiSessionDataSourceImplTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/remote/SwarmUiSessionDataSourceImplTest.kt new file mode 100644 index 00000000..6a1a3aa3 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/remote/SwarmUiSessionDataSourceImplTest.kt @@ -0,0 +1,92 @@ +package com.shifthackz.aisdv1.data.remote + +import com.shifthackz.aisdv1.data.provider.ServerUrlProvider +import com.shifthackz.aisdv1.domain.preference.SessionPreference +import com.shifthackz.aisdv1.network.api.swarmui.SwarmUiApi +import com.shifthackz.aisdv1.network.response.SwarmUiSessionResponse +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Single +import org.junit.Before +import org.junit.Test + +class SwarmUiSessionDataSourceImplTest { + + private val stubApi = mockk() + private val stubSessionPreference = mockk() + private val stubServerUrlProvider = mockk() + + private val remoteDataSource = SwarmUiSessionDataSourceImpl( + api = stubApi, + sessionPreference = stubSessionPreference, + serverUrlProvider = stubServerUrlProvider, + ) + + @Before + fun initialize() { + every { + stubServerUrlProvider(any()) + } returns Single.just("http://192.168.0.1:7801") + } + + @Test + fun `given session present in preference, expected sessionId value from preference`() { + every { + stubSessionPreference::swarmUiSessionId.get() + } returns "5598" + + remoteDataSource + .getSessionId() + .test() + .assertNoErrors() + .await() + .assertValue("5598") + .assertComplete() + } + + @Test + fun `given session NOT present in preference, API returns session, expected sessionId value from API`() { + every { + stubSessionPreference::swarmUiSessionId.get() + } returns "" + + every { + stubSessionPreference::swarmUiSessionId.set(any()) + } returns Unit + + every { + stubApi.getNewSession(any()) + } returns Single.just(SwarmUiSessionResponse("5598")) + + remoteDataSource + .getSessionId() + .test() + .assertNoErrors() + .await() + .assertValue("5598") + .assertComplete() + } + + @Test + fun `given session NOT present in preference, API returns null, expected error value`() { + every { + stubSessionPreference::swarmUiSessionId.get() + } returns "" + + every { + stubSessionPreference::swarmUiSessionId.set(any()) + } returns Unit + + every { + stubApi.getNewSession(any()) + } returns Single.just(SwarmUiSessionResponse(null)) + + remoteDataSource + .getSessionId() + .test() + .assertError { t -> t is IllegalStateException && t.message == "Bad session ID." } + .await() + .assertNoValues() + .assertNotComplete() + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/repository/EmbeddingsRepositoryImplTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/repository/EmbeddingsRepositoryImplTest.kt new file mode 100644 index 00000000..74e044b1 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/repository/EmbeddingsRepositoryImplTest.kt @@ -0,0 +1,417 @@ +package com.shifthackz.aisdv1.data.repository + +import com.shifthackz.aisdv1.data.mocks.mockEmbeddings +import com.shifthackz.aisdv1.domain.datasource.EmbeddingsDataSource +import com.shifthackz.aisdv1.domain.datasource.SwarmUiSessionDataSource +import com.shifthackz.aisdv1.domain.entity.ServerSource +import com.shifthackz.aisdv1.domain.preference.PreferenceManager +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single +import org.junit.Before +import org.junit.Test + +class EmbeddingsRepositoryImplTest { + + private val stubException = Throwable("Something went wrong.") + private val stubRdsA1111 = mockk() + private val stubRdsSwarm = mockk() + private val stubSwarmSession = mockk() + private val stubLds = mockk() + private val stubPreferenceManager = mockk() + + private val repository = EmbeddingsRepositoryImpl( + rdsA1111 = stubRdsA1111, + rdsSwarm = stubRdsSwarm, + swarmSession = stubSwarmSession, + lds = stubLds, + preferenceManager = stubPreferenceManager, + ) + + @Before + fun initialize() { + every { + stubSwarmSession.handleSessionError(any>()) + } returnsArgument 0 + } + + @Test + fun `given attempt to fetch embeddings, source is AUTOMATIC1111, remote returns data, local insert success, expected complete value`() { + every { + stubPreferenceManager::source.get() + } returns ServerSource.AUTOMATIC1111 + + every { + stubRdsA1111.fetchEmbeddings() + } returns Single.just(mockEmbeddings) + + every { + stubLds.insertEmbeddings(any()) + } returns Completable.complete() + + repository + .fetchEmbeddings() + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch embeddings, source is SWARM_UI, remote returns data, local insert success, expected complete value`() { + every { + stubPreferenceManager::source.get() + } returns ServerSource.SWARM_UI + + every { + stubSwarmSession.getSessionId(any()) + } returns Single.just("5598") + + every { + stubSwarmSession.getSessionId() + } returns Single.just("5598") + + every { + stubRdsSwarm.fetchEmbeddings(any()) + } returns Single.just(mockEmbeddings) + + every { + stubLds.insertEmbeddings(any()) + } returns Completable.complete() + + repository + .fetchEmbeddings() + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch embeddings, source is AUTOMATIC1111, remote throws exception, local insert success, expected error value`() { + every { + stubPreferenceManager::source.get() + } returns ServerSource.AUTOMATIC1111 + + every { + stubRdsA1111.fetchEmbeddings() + } returns Single.error(stubException) + + every { + stubLds.insertEmbeddings(any()) + } returns Completable.complete() + + repository + .fetchEmbeddings() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to fetch embeddings, source is SWARM_UI, remote throws exception, local insert success, expected error value`() { + every { + stubPreferenceManager::source.get() + } returns ServerSource.SWARM_UI + + every { + stubSwarmSession.getSessionId(any()) + } returns Single.just("5598") + + every { + stubSwarmSession.getSessionId() + } returns Single.just("5598") + + every { + stubRdsSwarm.fetchEmbeddings(any()) + } returns Single.error(stubException) + + every { + stubLds.insertEmbeddings(any()) + } returns Completable.complete() + + repository + .fetchEmbeddings() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to fetch embeddings, source is AUTOMATIC1111, remote returns data, local insert fails, expected error value`() { + every { + stubPreferenceManager::source.get() + } returns ServerSource.AUTOMATIC1111 + + every { + stubRdsA1111.fetchEmbeddings() + } returns Single.just(mockEmbeddings) + + every { + stubLds.insertEmbeddings(any()) + } returns Completable.error(stubException) + + repository + .fetchEmbeddings() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to fetch embeddings, source is SWARM_UI, remote returns data, local insert fails, expected error value`() { + every { + stubPreferenceManager::source.get() + } returns ServerSource.SWARM_UI + + every { + stubSwarmSession.getSessionId(any()) + } returns Single.just("5598") + + every { + stubSwarmSession.getSessionId() + } returns Single.just("5598") + + every { + stubRdsSwarm.fetchEmbeddings(any()) + } returns Single.just(mockEmbeddings) + + every { + stubLds.insertEmbeddings(any()) + } returns Completable.error(stubException) + + repository + .fetchEmbeddings() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to get embeddings, local data source returns list, expected valid domain models list value`() { + every { + stubLds.getEmbeddings() + } returns Single.just(mockEmbeddings) + + repository + .getEmbeddings() + .test() + .assertNoErrors() + .assertValue(mockEmbeddings) + .await() + .assertComplete() + } + + @Test + fun `given attempt to get embeddings, local data source returns empty list, expected empty domain models list value`() { + every { + stubLds.getEmbeddings() + } returns Single.just(emptyList()) + + repository + .getEmbeddings() + .test() + .assertNoErrors() + .assertValue(emptyList()) + .await() + .assertComplete() + } + + @Test + fun `given attempt to get embeddings, local data source throws exception, expected error value`() { + every { + stubLds.getEmbeddings() + } returns Single.error(stubException) + + repository + .getEmbeddings() + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to fetch and get embeddings, source is AUTOMATIC1111, remote returns data, local returns data, expected valid domain models list value`() { + every { + stubPreferenceManager::source.get() + } returns ServerSource.AUTOMATIC1111 + + every { + stubRdsA1111.fetchEmbeddings() + } returns Single.just(mockEmbeddings) + + every { + stubLds.insertEmbeddings(any()) + } returns Completable.complete() + + every { + stubLds.getEmbeddings() + } returns Single.just(mockEmbeddings) + + repository + .fetchAndGetEmbeddings() + .test() + .assertNoErrors() + .assertValue(mockEmbeddings) + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch and get embeddings, source is SWARM_UI, remote returns data, local returns data, expected valid domain models list value`() { + every { + stubPreferenceManager::source.get() + } returns ServerSource.SWARM_UI + + every { + stubSwarmSession.getSessionId(any()) + } returns Single.just("5598") + + every { + stubSwarmSession.getSessionId() + } returns Single.just("5598") + + every { + stubRdsSwarm.fetchEmbeddings(any()) + } returns Single.just(mockEmbeddings) + + every { + stubLds.insertEmbeddings(any()) + } returns Completable.complete() + + every { + stubLds.getEmbeddings() + } returns Single.just(mockEmbeddings) + + repository + .fetchAndGetEmbeddings() + .test() + .assertNoErrors() + .assertValue(mockEmbeddings) + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch and get embeddings, source is AUTOMATIC1111, remote fails, local returns data, expected valid domain models list value`() { + every { + stubPreferenceManager::source.get() + } returns ServerSource.AUTOMATIC1111 + + every { + stubRdsA1111.fetchEmbeddings() + } returns Single.error(stubException) + + every { + stubLds.insertEmbeddings(any()) + } returns Completable.complete() + + every { + stubLds.getEmbeddings() + } returns Single.just(mockEmbeddings) + + repository + .fetchAndGetEmbeddings() + .test() + .assertNoErrors() + .assertValue(mockEmbeddings) + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch and get embeddings, source is SWARM_UI, remote fails, local returns data, expected valid domain models list value`() { + every { + stubPreferenceManager::source.get() + } returns ServerSource.SWARM_UI + + every { + stubSwarmSession.getSessionId(any()) + } returns Single.just("5598") + + every { + stubSwarmSession.getSessionId() + } returns Single.just("5598") + + every { + stubRdsA1111.fetchEmbeddings() + } returns Single.error(stubException) + + every { + stubLds.insertEmbeddings(any()) + } returns Completable.complete() + + every { + stubLds.getEmbeddings() + } returns Single.just(mockEmbeddings) + + repository + .fetchAndGetEmbeddings() + .test() + .assertNoErrors() + .assertValue(mockEmbeddings) + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch and get embeddings, source is AUTOMATIC1111, remote fails, local fails, expected valid error value`() { + every { + stubPreferenceManager::source.get() + } returns ServerSource.AUTOMATIC1111 + + every { + stubRdsA1111.fetchEmbeddings() + } returns Single.error(stubException) + + every { + stubLds.getEmbeddings() + } returns Single.error(stubException) + + repository + .fetchAndGetEmbeddings() + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to fetch and get embeddings, source is SWARM_UI, remote fails, local fails, expected valid error value`() { + every { + stubPreferenceManager::source.get() + } returns ServerSource.SWARM_UI + + every { + stubSwarmSession.getSessionId(any()) + } returns Single.just("5598") + + every { + stubSwarmSession.getSessionId() + } returns Single.just("5598") + + every { + stubRdsA1111.fetchEmbeddings() + } returns Single.error(stubException) + + every { + stubLds.getEmbeddings() + } returns Single.error(stubException) + + repository + .fetchAndGetEmbeddings() + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/repository/LorasRepositoryImplTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/repository/LorasRepositoryImplTest.kt new file mode 100644 index 00000000..3bff4b37 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/repository/LorasRepositoryImplTest.kt @@ -0,0 +1,417 @@ +package com.shifthackz.aisdv1.data.repository + +import com.shifthackz.aisdv1.data.mocks.mockStableDiffusionLoras +import com.shifthackz.aisdv1.domain.datasource.LorasDataSource +import com.shifthackz.aisdv1.domain.datasource.SwarmUiSessionDataSource +import com.shifthackz.aisdv1.domain.entity.ServerSource +import com.shifthackz.aisdv1.domain.preference.PreferenceManager +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single +import org.junit.Before +import org.junit.Test + +class LorasRepositoryImplTest { + + private val stubException = Throwable("Something went wrong.") + private val stubRdsA1111 = mockk() + private val stubRdsSwarm = mockk() + private val stubSwarmSession = mockk() + private val stubLds = mockk() + private val stubPreferenceManager = mockk() + + private val repository = LorasRepositoryImpl( + rdsA1111 = stubRdsA1111, + rdsSwarm = stubRdsSwarm, + swarmSession = stubSwarmSession, + lds = stubLds, + preferenceManager = stubPreferenceManager, + ) + + @Before + fun initialize() { + every { + stubSwarmSession.handleSessionError(any>()) + } returnsArgument 0 + } + + @Test + fun `given attempt to fetch loras, source is AUTOMATIC1111, remote returns data, local insert success, expected complete value`() { + every { + stubPreferenceManager::source.get() + } returns ServerSource.AUTOMATIC1111 + + every { + stubRdsA1111.fetchLoras() + } returns Single.just(mockStableDiffusionLoras) + + every { + stubLds.insertLoras(any()) + } returns Completable.complete() + + repository + .fetchLoras() + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch loras, source is SWARM_UI, remote returns data, local insert success, expected complete value`() { + every { + stubPreferenceManager::source.get() + } returns ServerSource.SWARM_UI + + every { + stubSwarmSession.getSessionId(any()) + } returns Single.just("5598") + + every { + stubSwarmSession.getSessionId() + } returns Single.just("5598") + + every { + stubRdsSwarm.fetchLoras(any()) + } returns Single.just(mockStableDiffusionLoras) + + every { + stubLds.insertLoras(any()) + } returns Completable.complete() + + repository + .fetchLoras() + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch loras, source is AUTOMATIC1111, remote throws exception, local insert success, expected error value`() { + every { + stubPreferenceManager::source.get() + } returns ServerSource.AUTOMATIC1111 + + every { + stubRdsA1111.fetchLoras() + } returns Single.error(stubException) + + every { + stubLds.insertLoras(any()) + } returns Completable.complete() + + repository + .fetchLoras() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to fetch loras, source is SWARM_UI, remote throws exception, local insert success, expected error value`() { + every { + stubPreferenceManager::source.get() + } returns ServerSource.SWARM_UI + + every { + stubSwarmSession.getSessionId(any()) + } returns Single.error(stubException) + + every { + stubSwarmSession.getSessionId() + } returns Single.error(stubException) + + every { + stubRdsSwarm.fetchLoras(any()) + } returns Single.error(stubException) + + every { + stubLds.insertLoras(any()) + } returns Completable.complete() + + repository + .fetchLoras() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to fetch loras, source is AUTOMATIC1111, remote returns data, local insert fails, expected error value`() { + every { + stubPreferenceManager::source.get() + } returns ServerSource.AUTOMATIC1111 + + every { + stubRdsA1111.fetchLoras() + } returns Single.just(mockStableDiffusionLoras) + + every { + stubLds.insertLoras(any()) + } returns Completable.error(stubException) + + repository + .fetchLoras() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to fetch loras, source is SWARM_UI, remote returns data, local insert fails, expected error value`() { + every { + stubPreferenceManager::source.get() + } returns ServerSource.SWARM_UI + + every { + stubSwarmSession.getSessionId(any()) + } returns Single.just("5598") + + every { + stubSwarmSession.getSessionId() + } returns Single.just("5598") + + every { + stubRdsSwarm.fetchLoras(any()) + } returns Single.just(mockStableDiffusionLoras) + + every { + stubLds.insertLoras(any()) + } returns Completable.error(stubException) + + repository + .fetchLoras() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to get loras, local data source returns list, expected valid domain models list value`() { + every { + stubLds.getLoras() + } returns Single.just(mockStableDiffusionLoras) + + repository + .getLoras() + .test() + .assertNoErrors() + .assertValue(mockStableDiffusionLoras) + .await() + .assertComplete() + } + + @Test + fun `given attempt to get loras, local data source returns empty list, expected empty domain models list value`() { + every { + stubLds.getLoras() + } returns Single.just(emptyList()) + + repository + .getLoras() + .test() + .assertNoErrors() + .assertValue(emptyList()) + .await() + .assertComplete() + } + + @Test + fun `given attempt to get loras, local data source throws exception, expected error value`() { + every { + stubLds.getLoras() + } returns Single.error(stubException) + + repository + .getLoras() + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to fetch and get loras, source is AUTOMATIC1111, remote returns data, local returns data, expected valid domain models list value`() { + every { + stubPreferenceManager::source.get() + } returns ServerSource.AUTOMATIC1111 + + every { + stubRdsA1111.fetchLoras() + } returns Single.just(mockStableDiffusionLoras) + + every { + stubLds.insertLoras(any()) + } returns Completable.complete() + + every { + stubLds.getLoras() + } returns Single.just(mockStableDiffusionLoras) + + repository + .fetchAndGetLoras() + .test() + .assertNoErrors() + .assertValue(mockStableDiffusionLoras) + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch and get loras, source is SWARM_UI, remote returns data, local returns data, expected valid domain models list value`() { + every { + stubPreferenceManager::source.get() + } returns ServerSource.SWARM_UI + + every { + stubSwarmSession.getSessionId(any()) + } returns Single.just("5598") + + every { + stubSwarmSession.getSessionId() + } returns Single.just("5598") + + every { + stubRdsSwarm.fetchLoras(any()) + } returns Single.just(mockStableDiffusionLoras) + + every { + stubLds.insertLoras(any()) + } returns Completable.complete() + + every { + stubLds.getLoras() + } returns Single.just(mockStableDiffusionLoras) + + repository + .fetchAndGetLoras() + .test() + .assertNoErrors() + .assertValue(mockStableDiffusionLoras) + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch and get loras, source is AUTOMATIC1111, remote fails, local returns data, expected valid domain models list value`() { + every { + stubPreferenceManager::source.get() + } returns ServerSource.AUTOMATIC1111 + + every { + stubRdsA1111.fetchLoras() + } returns Single.error(stubException) + + every { + stubLds.insertLoras(any()) + } returns Completable.complete() + + every { + stubLds.getLoras() + } returns Single.just(mockStableDiffusionLoras) + + repository + .fetchAndGetLoras() + .test() + .assertNoErrors() + .assertValue(mockStableDiffusionLoras) + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch and get loras, source is SWARM_UI, remote fails, local returns data, expected valid domain models list value`() { + every { + stubPreferenceManager::source.get() + } returns ServerSource.SWARM_UI + + every { + stubSwarmSession.getSessionId(any()) + } returns Single.just("5598") + + every { + stubSwarmSession.getSessionId() + } returns Single.just("5598") + + every { + stubRdsSwarm.fetchLoras(any()) + } returns Single.error(stubException) + + every { + stubLds.insertLoras(any()) + } returns Completable.complete() + + every { + stubLds.getLoras() + } returns Single.just(mockStableDiffusionLoras) + + repository + .fetchAndGetLoras() + .test() + .assertNoErrors() + .assertValue(mockStableDiffusionLoras) + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch and get loras, source is AUTOMATIC1111, remote fails, local fails, expected valid error value`() { + every { + stubPreferenceManager::source.get() + } returns ServerSource.AUTOMATIC1111 + + every { + stubRdsA1111.fetchLoras() + } returns Single.error(stubException) + + every { + stubLds.getLoras() + } returns Single.error(stubException) + + repository + .fetchAndGetLoras() + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to fetch and get loras, source is SWARM_UI, remote fails, local fails, expected valid error value`() { + every { + stubPreferenceManager::source.get() + } returns ServerSource.SWARM_UI + + every { + stubSwarmSession.getSessionId(any()) + } returns Single.just("5598") + + every { + stubSwarmSession.getSessionId() + } returns Single.just("5598") + + every { + stubRdsSwarm.fetchLoras(any()) + } returns Single.error(stubException) + + every { + stubLds.getLoras() + } returns Single.error(stubException) + + repository + .fetchAndGetLoras() + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/repository/StableDiffusionEmbeddingsRepositoryImplTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/repository/StableDiffusionEmbeddingsRepositoryImplTest.kt deleted file mode 100644 index 933c19c3..00000000 --- a/data/src/test/java/com/shifthackz/aisdv1/data/repository/StableDiffusionEmbeddingsRepositoryImplTest.kt +++ /dev/null @@ -1,185 +0,0 @@ -package com.shifthackz.aisdv1.data.repository - -import com.shifthackz.aisdv1.data.mocks.mockStableDiffusionEmbeddings -import com.shifthackz.aisdv1.domain.datasource.StableDiffusionEmbeddingsDataSource -import io.mockk.every -import io.mockk.mockk -import io.reactivex.rxjava3.core.Completable -import io.reactivex.rxjava3.core.Single -import org.junit.Test - -class StableDiffusionEmbeddingsRepositoryImplTest { - - private val stubException = Throwable("Something went wrong.") - private val stubRemoteDataSource = mockk() - private val stubLocalDataSource = mockk() - - private val repository = StableDiffusionEmbeddingsRepositoryImpl( - remoteDataSource = stubRemoteDataSource, - localDataSource = stubLocalDataSource, - ) - - @Test - fun `given attempt to fetch embeddings, remote returns data, local insert success, expected complete value`() { - every { - stubRemoteDataSource.fetchEmbeddings() - } returns Single.just(mockStableDiffusionEmbeddings) - - every { - stubLocalDataSource.insertEmbeddings(any()) - } returns Completable.complete() - - repository - .fetchEmbeddings() - .test() - .assertNoErrors() - .await() - .assertComplete() - } - - @Test - fun `given attempt to fetch embeddings, remote throws exception, local insert success, expected error value`() { - every { - stubRemoteDataSource.fetchEmbeddings() - } returns Single.error(stubException) - - every { - stubLocalDataSource.insertEmbeddings(any()) - } returns Completable.complete() - - repository - .fetchEmbeddings() - .test() - .assertError(stubException) - .await() - .assertNotComplete() - } - - @Test - fun `given attempt to fetch embeddings, remote returns data, local insert fails, expected error value`() { - every { - stubRemoteDataSource.fetchEmbeddings() - } returns Single.just(mockStableDiffusionEmbeddings) - - every { - stubLocalDataSource.insertEmbeddings(any()) - } returns Completable.error(stubException) - - repository - .fetchEmbeddings() - .test() - .assertError(stubException) - .await() - .assertNotComplete() - } - - @Test - fun `given attempt to get embeddings, local data source returns list, expected valid domain models list value`() { - every { - stubLocalDataSource.getEmbeddings() - } returns Single.just(mockStableDiffusionEmbeddings) - - repository - .getEmbeddings() - .test() - .assertNoErrors() - .assertValue(mockStableDiffusionEmbeddings) - .await() - .assertComplete() - } - - @Test - fun `given attempt to get embeddings, local data source returns empty list, expected empty domain models list value`() { - every { - stubLocalDataSource.getEmbeddings() - } returns Single.just(emptyList()) - - repository - .getEmbeddings() - .test() - .assertNoErrors() - .assertValue(emptyList()) - .await() - .assertComplete() - } - - @Test - fun `given attempt to get embeddings, local data source throws exception, expected error value`() { - every { - stubLocalDataSource.getEmbeddings() - } returns Single.error(stubException) - - repository - .getEmbeddings() - .test() - .assertError(stubException) - .assertNoValues() - .await() - .assertNotComplete() - } - - @Test - fun `given attempt to fetch and get embeddings, remote returns data, local returns data, expected valid domain models list value`() { - every { - stubRemoteDataSource.fetchEmbeddings() - } returns Single.just(mockStableDiffusionEmbeddings) - - every { - stubLocalDataSource.insertEmbeddings(any()) - } returns Completable.complete() - - every { - stubLocalDataSource.getEmbeddings() - } returns Single.just(mockStableDiffusionEmbeddings) - - repository - .fetchAndGetEmbeddings() - .test() - .assertNoErrors() - .assertValue(mockStableDiffusionEmbeddings) - .await() - .assertComplete() - } - - @Test - fun `given attempt to fetch and get embeddings, remote fails, local returns data, expected valid domain models list value`() { - every { - stubRemoteDataSource.fetchEmbeddings() - } returns Single.error(stubException) - - every { - stubLocalDataSource.insertEmbeddings(any()) - } returns Completable.complete() - - every { - stubLocalDataSource.getEmbeddings() - } returns Single.just(mockStableDiffusionEmbeddings) - - repository - .fetchAndGetEmbeddings() - .test() - .assertNoErrors() - .assertValue(mockStableDiffusionEmbeddings) - .await() - .assertComplete() - } - - @Test - fun `given attempt to fetch and get embeddings, remote fails, local fails, expected valid error value`() { - every { - stubRemoteDataSource.fetchEmbeddings() - } returns Single.error(stubException) - - every { - stubLocalDataSource.getEmbeddings() - } returns Single.error(stubException) - - repository - .fetchAndGetEmbeddings() - .test() - .assertError(stubException) - .assertNoValues() - .await() - .assertNotComplete() - } -} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/repository/StableDiffusionLorasRepositoryImplTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/repository/StableDiffusionLorasRepositoryImplTest.kt deleted file mode 100644 index 5fbf1aaa..00000000 --- a/data/src/test/java/com/shifthackz/aisdv1/data/repository/StableDiffusionLorasRepositoryImplTest.kt +++ /dev/null @@ -1,186 +0,0 @@ -package com.shifthackz.aisdv1.data.repository - -import com.shifthackz.aisdv1.data.mocks.mockStableDiffusionLoras -import com.shifthackz.aisdv1.data.mocks.mockStableDiffusionModels -import com.shifthackz.aisdv1.domain.datasource.StableDiffusionLorasDataSource -import io.mockk.every -import io.mockk.mockk -import io.reactivex.rxjava3.core.Completable -import io.reactivex.rxjava3.core.Single -import org.junit.Test - -class StableDiffusionLorasRepositoryImplTest { - - private val stubException = Throwable("Something went wrong.") - private val stubRemoteDataSource = mockk() - private val stubLocalDataSource = mockk() - - private val repository = StableDiffusionLorasRepositoryImpl( - remoteDataSource = stubRemoteDataSource, - localDataSource = stubLocalDataSource, - ) - - @Test - fun `given attempt to fetch loras, remote returns data, local insert success, expected complete value`() { - every { - stubRemoteDataSource.fetchLoras() - } returns Single.just(mockStableDiffusionLoras) - - every { - stubLocalDataSource.insertLoras(any()) - } returns Completable.complete() - - repository - .fetchLoras() - .test() - .assertNoErrors() - .await() - .assertComplete() - } - - @Test - fun `given attempt to fetch loras, remote throws exception, local insert success, expected error value`() { - every { - stubRemoteDataSource.fetchLoras() - } returns Single.error(stubException) - - every { - stubLocalDataSource.insertLoras(any()) - } returns Completable.complete() - - repository - .fetchLoras() - .test() - .assertError(stubException) - .await() - .assertNotComplete() - } - - @Test - fun `given attempt to fetch loras, remote returns data, local insert fails, expected error value`() { - every { - stubRemoteDataSource.fetchLoras() - } returns Single.just(mockStableDiffusionLoras) - - every { - stubLocalDataSource.insertLoras(any()) - } returns Completable.error(stubException) - - repository - .fetchLoras() - .test() - .assertError(stubException) - .await() - .assertNotComplete() - } - - @Test - fun `given attempt to get loras, local data source returns list, expected valid domain models list value`() { - every { - stubLocalDataSource.getLoras() - } returns Single.just(mockStableDiffusionLoras) - - repository - .getLoras() - .test() - .assertNoErrors() - .assertValue(mockStableDiffusionLoras) - .await() - .assertComplete() - } - - @Test - fun `given attempt to get loras, local data source returns empty list, expected empty domain models list value`() { - every { - stubLocalDataSource.getLoras() - } returns Single.just(emptyList()) - - repository - .getLoras() - .test() - .assertNoErrors() - .assertValue(emptyList()) - .await() - .assertComplete() - } - - @Test - fun `given attempt to get loras, local data source throws exception, expected error value`() { - every { - stubLocalDataSource.getLoras() - } returns Single.error(stubException) - - repository - .getLoras() - .test() - .assertError(stubException) - .assertNoValues() - .await() - .assertNotComplete() - } - - @Test - fun `given attempt to fetch and get loras, remote returns data, local returns data, expected valid domain models list value`() { - every { - stubRemoteDataSource.fetchLoras() - } returns Single.just(mockStableDiffusionLoras) - - every { - stubLocalDataSource.insertLoras(any()) - } returns Completable.complete() - - every { - stubLocalDataSource.getLoras() - } returns Single.just(mockStableDiffusionLoras) - - repository - .fetchAndGetLoras() - .test() - .assertNoErrors() - .assertValue(mockStableDiffusionLoras) - .await() - .assertComplete() - } - - @Test - fun `given attempt to fetch and get loras, remote fails, local returns data, expected valid domain models list value`() { - every { - stubRemoteDataSource.fetchLoras() - } returns Single.error(stubException) - - every { - stubLocalDataSource.insertLoras(any()) - } returns Completable.complete() - - every { - stubLocalDataSource.getLoras() - } returns Single.just(mockStableDiffusionLoras) - - repository - .fetchAndGetLoras() - .test() - .assertNoErrors() - .assertValue(mockStableDiffusionLoras) - .await() - .assertComplete() - } - - @Test - fun `given attempt to fetch and get loras, remote fails, local fails, expected valid error value`() { - every { - stubRemoteDataSource.fetchLoras() - } returns Single.error(stubException) - - every { - stubLocalDataSource.getLoras() - } returns Single.error(stubException) - - repository - .fetchAndGetLoras() - .test() - .assertError(stubException) - .assertNoValues() - .await() - .assertNotComplete() - } -} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/repository/SwarmUiGenerationRepositoryImplTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/repository/SwarmUiGenerationRepositoryImplTest.kt new file mode 100644 index 00000000..53e09a05 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/repository/SwarmUiGenerationRepositoryImplTest.kt @@ -0,0 +1,211 @@ +package com.shifthackz.aisdv1.data.repository + +import com.shifthackz.aisdv1.core.imageprocessing.Base64ToBitmapConverter +import com.shifthackz.aisdv1.data.mocks.mockAiGenerationResult +import com.shifthackz.aisdv1.data.mocks.mockImageToImagePayload +import com.shifthackz.aisdv1.data.mocks.mockTextToImagePayload +import com.shifthackz.aisdv1.domain.datasource.GenerationResultDataSource +import com.shifthackz.aisdv1.domain.datasource.SwarmUiGenerationDataSource +import com.shifthackz.aisdv1.domain.datasource.SwarmUiSessionDataSource +import com.shifthackz.aisdv1.domain.gateway.MediaStoreGateway +import com.shifthackz.aisdv1.domain.preference.PreferenceManager +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Single +import org.junit.Before +import org.junit.Test + +class SwarmUiGenerationRepositoryImplTest { + + private val stubException = Throwable("Something went wrong.") + private val stubMediaStoreGateway = mockk() + private val stubBase64ToBitmapConverter = mockk() + private val stubLocalDataSource = mockk() + private val stubRemoteDataSource = mockk() + private val stubSession = mockk() + private val stubPreferenceManager = mockk() + + private val repository = SwarmUiGenerationRepositoryImpl( + mediaStoreGateway = stubMediaStoreGateway, + base64ToBitmapConverter = stubBase64ToBitmapConverter, + localDataSource = stubLocalDataSource, + remoteDataSource = stubRemoteDataSource, + session = stubSession, + preferenceManager = stubPreferenceManager, + ) + + @Before + fun initialize() { + every { + stubPreferenceManager.autoSaveAiResults + } returns false + + every { + stubSession.handleSessionError(any>()) + } returnsArgument 0 + } + + @Test + fun `given attempt to check api availability, remote completes, expected complete value`() { + every { + stubSession.getSessionId() + } returns Single.just("5598") + + repository + .checkApiAvailability() + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `given attempt to check api availability, remote throws exception, expected error value`() { + every { + stubSession.getSessionId() + } returns Single.error(stubException) + + repository + .checkApiAvailability() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to check api availability by url, remote completes, expected complete value`() { + every { + stubSession.getSessionId(any()) + } returns Single.just("5598") + + repository + .checkApiAvailability("https://5598.is.my.favourite.com") + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `given attempt to check api availability by url, remote throws exception, expected error value`() { + every { + stubSession.getSessionId(any()) + } returns Single.error(stubException) + + repository + .checkApiAvailability("https://5598.is.my.favourite.com") + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to generate from text, remote returns result, expected valid domain model value`() { + every { + stubPreferenceManager::swarmUiModel.get() + } returns "5598" + + every { + stubSession.getSessionId(any()) + } returns Single.just("5598") + + every { + stubSession.getSessionId() + } returns Single.just("5598") + + every { + stubRemoteDataSource.textToImage(any(), any(), any()) + } returns Single.just(mockAiGenerationResult) + + repository + .generateFromText(mockTextToImagePayload) + .test() + .assertNoErrors() + .assertValue(mockAiGenerationResult) + .await() + .assertComplete() + } + + @Test + fun `given attempt to generate from text, remote throws exception, expected error value`() { + every { + stubPreferenceManager::swarmUiModel.get() + } returns "5598" + + every { + stubSession.getSessionId(any()) + } returns Single.just("5598") + + every { + stubSession.getSessionId() + } returns Single.just("5598") + + every { + stubRemoteDataSource.textToImage(any(), any(), any()) + } returns Single.error(stubException) + + repository + .generateFromText(mockTextToImagePayload) + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to generate from image, remote returns result, expected valid domain model value`() { + every { + stubPreferenceManager::swarmUiModel.get() + } returns "5598" + + every { + stubSession.getSessionId(any()) + } returns Single.just("5598") + + every { + stubSession.getSessionId() + } returns Single.just("5598") + + every { + stubRemoteDataSource.imageToImage(any(), any(), any()) + } returns Single.just(mockAiGenerationResult) + + repository + .generateFromImage(mockImageToImagePayload) + .test() + .assertNoErrors() + .assertValue(mockAiGenerationResult) + .await() + .assertComplete() + } + + @Test + fun `given attempt to generate from image, remote throws exception, expected error value`() { + every { + stubPreferenceManager::swarmUiModel.get() + } returns "5598" + + every { + stubSession.getSessionId(any()) + } returns Single.just("5598") + + every { + stubSession.getSessionId() + } returns Single.just("5598") + + every { + stubRemoteDataSource.imageToImage(any(), any(), any()) + } returns Single.error(stubException) + + repository + .generateFromImage(mockImageToImagePayload) + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/repository/SwarmUiModelsRepositoryImplTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/repository/SwarmUiModelsRepositoryImplTest.kt new file mode 100644 index 00000000..ca6cb6cf --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/repository/SwarmUiModelsRepositoryImplTest.kt @@ -0,0 +1,200 @@ +package com.shifthackz.aisdv1.data.repository + +import com.shifthackz.aisdv1.data.mocks.mockSwarmUiModels +import com.shifthackz.aisdv1.domain.datasource.SwarmUiModelsDataSource +import com.shifthackz.aisdv1.domain.datasource.SwarmUiSessionDataSource +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single +import org.junit.Before +import org.junit.Test + +class SwarmUiModelsRepositoryImplTest { + + private val stubException = Throwable("Something went wrong.") + private val stubSession = mockk() + private val stubRds = mockk() + private val stubLds = mockk() + + private val repository = SwarmUiModelsRepositoryImpl(stubSession, stubRds, stubLds) + + @Before + fun initialize() { + every { + stubSession.getSessionId(any()) + } returns Single.just("5598") + + every { + stubSession.getSessionId() + } returns Single.just("5598") + + every { + stubSession.handleSessionError(any>()) + } returnsArgument 0 + } + + @Test + fun `given attempt to fetch models, remote returns data, local insert success, expected complete value`() { + every { + stubRds.fetchSwarmModels(any()) + } returns Single.just(mockSwarmUiModels) + + every { + stubLds.insertModels(any()) + } returns Completable.complete() + + repository + .fetchModels() + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch models, remote throws exception, local insert success, expected error value`() { + every { + stubRds.fetchSwarmModels(any()) + } returns Single.error(stubException) + + every { + stubLds.insertModels(any()) + } returns Completable.complete() + + repository + .fetchModels() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to fetch models, remote returns data, local insert fails, expected error value`() { + every { + stubRds.fetchSwarmModels(any()) + } returns Single.just(mockSwarmUiModels) + + every { + stubLds.insertModels(any()) + } returns Completable.error(stubException) + + repository + .fetchModels() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to get models, local data source returns list, expected valid domain models list value`() { + every { + stubLds.getModels() + } returns Single.just(mockSwarmUiModels) + + repository + .getModels() + .test() + .assertNoErrors() + .assertValue(mockSwarmUiModels) + .await() + .assertComplete() + } + + @Test + fun `given attempt to get models, local data source returns empty list, expected empty domain models list value`() { + every { + stubLds.getModels() + } returns Single.just(emptyList()) + + repository + .getModels() + .test() + .assertNoErrors() + .assertValue(emptyList()) + .await() + .assertComplete() + } + + @Test + fun `given attempt to get models, local data source throws exception, expected error value`() { + every { + stubLds.getModels() + } returns Single.error(stubException) + + repository + .getModels() + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to fetch and get models, remote returns data, local returns data, expected valid domain models list value`() { + every { + stubRds.fetchSwarmModels(any()) + } returns Single.just(mockSwarmUiModels) + + every { + stubLds.insertModels(any()) + } returns Completable.complete() + + every { + stubLds.getModels() + } returns Single.just(mockSwarmUiModels) + + repository + .fetchAndGetModels() + .test() + .assertNoErrors() + .assertValue(mockSwarmUiModels) + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch and get models, remote fails, local returns data, expected valid domain models list value`() { + every { + stubRds.fetchSwarmModels(any()) + } returns Single.error(stubException) + + every { + stubLds.insertModels(any()) + } returns Completable.complete() + + every { + stubLds.getModels() + } returns Single.just(mockSwarmUiModels) + + repository + .fetchAndGetModels() + .test() + .assertNoErrors() + .assertValue(mockSwarmUiModels) + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch and get models, remote fails, local fails, expected valid error value`() { + every { + stubRds.fetchSwarmModels(any()) + } returns Single.error(stubException) + + every { + stubLds.getModels() + } returns Single.error(stubException) + + repository + .fetchAndGetModels() + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } +} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/EmbeddingsDataSource.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/EmbeddingsDataSource.kt new file mode 100644 index 00000000..cadd6b43 --- /dev/null +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/EmbeddingsDataSource.kt @@ -0,0 +1,24 @@ +package com.shifthackz.aisdv1.domain.datasource + +import com.shifthackz.aisdv1.domain.entity.Embedding +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single + +sealed interface EmbeddingsDataSource { + + interface Remote : EmbeddingsDataSource { + + interface Automatic1111 : Remote { + fun fetchEmbeddings(): Single> + } + + interface SwarmUi : Remote { + fun fetchEmbeddings(sessionId: String): Single> + } + } + + interface Local : EmbeddingsDataSource { + fun getEmbeddings(): Single> + fun insertEmbeddings(list: List): Completable + } +} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/LorasDataSource.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/LorasDataSource.kt new file mode 100644 index 00000000..fd61ed0f --- /dev/null +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/LorasDataSource.kt @@ -0,0 +1,24 @@ +package com.shifthackz.aisdv1.domain.datasource + +import com.shifthackz.aisdv1.domain.entity.LoRA +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single + +sealed interface LorasDataSource { + + sealed interface Remote : LorasDataSource { + + interface Automatic1111 : Remote { + fun fetchLoras(): Single> + } + + interface SwarmUi : Remote { + fun fetchLoras(sessionId: String): Single> + } + } + + interface Local : LorasDataSource { + fun getLoras(): Single> + fun insertLoras(loras: List): Completable + } +} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/StableDiffusionEmbeddingsDataSource.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/StableDiffusionEmbeddingsDataSource.kt deleted file mode 100644 index a1ed1da8..00000000 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/StableDiffusionEmbeddingsDataSource.kt +++ /dev/null @@ -1,17 +0,0 @@ -package com.shifthackz.aisdv1.domain.datasource - -import com.shifthackz.aisdv1.domain.entity.StableDiffusionEmbedding -import io.reactivex.rxjava3.core.Completable -import io.reactivex.rxjava3.core.Single - -sealed interface StableDiffusionEmbeddingsDataSource { - - interface Remote : StableDiffusionEmbeddingsDataSource { - fun fetchEmbeddings(): Single> - } - - interface Local : StableDiffusionEmbeddingsDataSource { - fun getEmbeddings(): Single> - fun insertEmbeddings(list: List): Completable - } -} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/StableDiffusionLorasDataSource.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/StableDiffusionLorasDataSource.kt deleted file mode 100644 index d8db5e7b..00000000 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/StableDiffusionLorasDataSource.kt +++ /dev/null @@ -1,18 +0,0 @@ -package com.shifthackz.aisdv1.domain.datasource - -import com.shifthackz.aisdv1.domain.entity.StableDiffusionLora -import io.reactivex.rxjava3.core.Completable -import io.reactivex.rxjava3.core.Single - -sealed interface StableDiffusionLorasDataSource { - - interface Remote : StableDiffusionLorasDataSource { - fun fetchLoras(): Single> - } - - interface Local : StableDiffusionLorasDataSource { - fun getLoras(): Single> - - fun insertLoras(loras: List): Completable - } -} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/SwarmUiGenerationDataSource.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/SwarmUiGenerationDataSource.kt new file mode 100644 index 00000000..88c388a0 --- /dev/null +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/SwarmUiGenerationDataSource.kt @@ -0,0 +1,23 @@ +package com.shifthackz.aisdv1.domain.datasource + +import com.shifthackz.aisdv1.domain.entity.AiGenerationResult +import com.shifthackz.aisdv1.domain.entity.ImageToImagePayload +import com.shifthackz.aisdv1.domain.entity.TextToImagePayload +import io.reactivex.rxjava3.core.Single + +sealed interface SwarmUiGenerationDataSource { + + interface Remote : SwarmUiGenerationDataSource { + fun textToImage( + sessionId: String, + model: String, + payload: TextToImagePayload, + ): Single + + fun imageToImage( + sessionId: String, + model: String, + payload: ImageToImagePayload, + ): Single + } +} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/SwarmUiModelsDataSource.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/SwarmUiModelsDataSource.kt new file mode 100644 index 00000000..f47f21c7 --- /dev/null +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/SwarmUiModelsDataSource.kt @@ -0,0 +1,17 @@ +package com.shifthackz.aisdv1.domain.datasource + +import com.shifthackz.aisdv1.domain.entity.SwarmUiModel +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single + +interface SwarmUiModelsDataSource { + + interface Remote : SwarmUiModelsDataSource { + fun fetchSwarmModels(sessionId: String): Single> + } + + interface Local : SwarmUiModelsDataSource { + fun getModels(): Single> + fun insertModels(models: List): Completable + } +} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/SwarmUiSessionDataSource.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/SwarmUiSessionDataSource.kt new file mode 100644 index 00000000..af6807db --- /dev/null +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/SwarmUiSessionDataSource.kt @@ -0,0 +1,9 @@ +package com.shifthackz.aisdv1.domain.datasource + +import io.reactivex.rxjava3.core.Single + +interface SwarmUiSessionDataSource { + fun getSessionId(connectUrl: String? = null): Single + fun forceRenew(connectUrl: String? = null): Single + fun handleSessionError(chain: Single): Single +} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/di/DomainModule.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/di/DomainModule.kt index 23338d53..7170f723 100755 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/di/DomainModule.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/di/DomainModule.kt @@ -26,6 +26,8 @@ import com.shifthackz.aisdv1.domain.usecase.connectivity.TestOpenAiApiKeyUseCase import com.shifthackz.aisdv1.domain.usecase.connectivity.TestOpenAiApiKeyUseCaseImpl import com.shifthackz.aisdv1.domain.usecase.connectivity.TestStabilityAiApiKeyUseCase import com.shifthackz.aisdv1.domain.usecase.connectivity.TestStabilityAiApiKeyUseCaseImpl +import com.shifthackz.aisdv1.domain.usecase.connectivity.TestSwarmUiConnectivityUseCase +import com.shifthackz.aisdv1.domain.usecase.connectivity.TestSwarmUiConnectivityUseCaseImpl import com.shifthackz.aisdv1.domain.usecase.debug.DebugInsertBadBase64UseCase import com.shifthackz.aisdv1.domain.usecase.debug.DebugInsertBadBase64UseCaseImpl import com.shifthackz.aisdv1.domain.usecase.downloadable.DeleteModelUseCase @@ -86,6 +88,8 @@ import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToOpenAiUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToOpenAiUseCaseImpl import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToStabilityAiUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToStabilityAiUseCaseImpl +import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToSwarmUiUseCase +import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToSwarmUiUseCaseImpl import com.shifthackz.aisdv1.domain.usecase.settings.GetConfigurationUseCase import com.shifthackz.aisdv1.domain.usecase.settings.GetConfigurationUseCaseImpl import com.shifthackz.aisdv1.domain.usecase.settings.SetServerConfigurationUseCase @@ -96,6 +100,8 @@ import com.shifthackz.aisdv1.domain.usecase.stabilityai.FetchAndGetStabilityAiEn import com.shifthackz.aisdv1.domain.usecase.stabilityai.FetchAndGetStabilityAiEnginesUseCaseImpl import com.shifthackz.aisdv1.domain.usecase.stabilityai.ObserveStabilityAiCreditsUseCase import com.shifthackz.aisdv1.domain.usecase.stabilityai.ObserveStabilityAiCreditsUseCaseImpl +import com.shifthackz.aisdv1.domain.usecase.swarmmodel.FetchAndGetSwarmUiModelsUseCase +import com.shifthackz.aisdv1.domain.usecase.swarmmodel.FetchAndGetSwarmUiModelsUseCaseImpl import com.shifthackz.aisdv1.domain.usecase.wakelock.AcquireWakelockUseCase import com.shifthackz.aisdv1.domain.usecase.wakelock.AcquireWakelockUseCaseImpl import com.shifthackz.aisdv1.domain.usecase.wakelock.ReleaseWakeLockUseCase @@ -110,6 +116,7 @@ internal val useCasesModule = module { factoryOf(::PingStableDiffusionServiceUseCaseImpl) bind PingStableDiffusionServiceUseCase::class factoryOf(::ClearAppCacheUseCaseImpl) bind ClearAppCacheUseCase::class factoryOf(::DataPreLoaderUseCaseImpl) bind DataPreLoaderUseCase::class + factoryOf(::FetchAndGetSwarmUiModelsUseCaseImpl) bind FetchAndGetSwarmUiModelsUseCase::class factoryOf(::GetStableDiffusionModelsUseCaseImpl) bind GetStableDiffusionModelsUseCase::class factoryOf(::SelectStableDiffusionModelUseCaseImpl) bind SelectStableDiffusionModelUseCase::class factoryOf(::GetGenerationResultPagedUseCaseImpl) bind GetGenerationResultPagedUseCase::class @@ -128,6 +135,7 @@ internal val useCasesModule = module { factoryOf(::TestHuggingFaceApiKeyUseCaseImpl) bind TestHuggingFaceApiKeyUseCase::class factoryOf(::TestOpenAiApiKeyUseCaseImpl) bind TestOpenAiApiKeyUseCase::class factoryOf(::TestStabilityAiApiKeyUseCaseImpl) bind TestStabilityAiApiKeyUseCase::class + factoryOf(::TestSwarmUiConnectivityUseCaseImpl) bind TestSwarmUiConnectivityUseCase::class factoryOf(::SaveGenerationResultUseCaseImpl) bind SaveGenerationResultUseCase::class factoryOf(::ObserveSeverConnectivityUseCaseImpl) bind ObserveSeverConnectivityUseCase::class factoryOf(::ObserveHordeProcessStatusUseCaseImpl) bind ObserveHordeProcessStatusUseCase::class @@ -146,6 +154,7 @@ internal val useCasesModule = module { factoryOf(::ConnectToHordeUseCaseImpl) bind ConnectToHordeUseCase::class factoryOf(::ConnectToLocalDiffusionUseCaseImpl) bind ConnectToLocalDiffusionUseCase::class factoryOf(::ConnectToA1111UseCaseImpl) bind ConnectToA1111UseCase::class + factoryOf(::ConnectToSwarmUiUseCaseImpl) bind ConnectToSwarmUiUseCase::class factoryOf(::ConnectToHuggingFaceUseCaseImpl) bind ConnectToHuggingFaceUseCase::class factoryOf(::ConnectToOpenAiUseCaseImpl) bind ConnectToOpenAiUseCase::class factoryOf(::ConnectToStabilityAiUseCaseImpl) bind ConnectToStabilityAiUseCase::class diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/Configuration.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/Configuration.kt index 1f2007d9..e644c6f8 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/Configuration.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/Configuration.kt @@ -4,6 +4,8 @@ import com.shifthackz.aisdv1.domain.feature.auth.AuthorizationCredentials data class Configuration( val serverUrl: String = "", + val swarmUiUrl: String = "", + val swarmUiModel: String = "", val demoMode: Boolean = false, val source: ServerSource = ServerSource.AUTOMATIC1111, val hordeApiKey: String = "", diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/StableDiffusionEmbedding.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/Embedding.kt similarity index 66% rename from domain/src/main/java/com/shifthackz/aisdv1/domain/entity/StableDiffusionEmbedding.kt rename to domain/src/main/java/com/shifthackz/aisdv1/domain/entity/Embedding.kt index a4faccc8..6e2d1b90 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/StableDiffusionEmbedding.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/Embedding.kt @@ -1,5 +1,5 @@ package com.shifthackz.aisdv1.domain.entity -data class StableDiffusionEmbedding( +data class Embedding( val keyword: String, ) diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/FeatureTag.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/FeatureTag.kt index 370da5a3..d4c7db9b 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/FeatureTag.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/FeatureTag.kt @@ -3,6 +3,7 @@ package com.shifthackz.aisdv1.domain.entity enum class FeatureTag { Txt2Img, Img2Img, + OwnServer, Lora, TextualInversion, HyperNetworks, diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/StableDiffusionLora.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/LoRA.kt similarity index 78% rename from domain/src/main/java/com/shifthackz/aisdv1/domain/entity/StableDiffusionLora.kt rename to domain/src/main/java/com/shifthackz/aisdv1/domain/entity/LoRA.kt index 25906524..54017459 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/StableDiffusionLora.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/LoRA.kt @@ -1,6 +1,6 @@ package com.shifthackz.aisdv1.domain.entity -data class StableDiffusionLora( +data class LoRA( val name: String, val alias: String, val path: String, diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/ServerSource.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/ServerSource.kt index 067e9e51..4eeca344 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/ServerSource.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/ServerSource.kt @@ -9,6 +9,7 @@ enum class ServerSource( featureTags = setOf( FeatureTag.Txt2Img, FeatureTag.Img2Img, + FeatureTag.OwnServer, FeatureTag.MultipleModels, FeatureTag.Lora, FeatureTag.TextualInversion, @@ -16,6 +17,18 @@ enum class ServerSource( FeatureTag.Batch, ), ), + SWARM_UI( + key = "swarm_ui", + featureTags = setOf( + FeatureTag.Txt2Img, + FeatureTag.OwnServer, + FeatureTag.Img2Img, + FeatureTag.MultipleModels, + FeatureTag.Lora, + FeatureTag.TextualInversion, + FeatureTag.Batch, + ), + ), HORDE( key = "horde", featureTags = setOf( diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/SwarmUiModel.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/SwarmUiModel.kt new file mode 100644 index 00000000..e79e2673 --- /dev/null +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/SwarmUiModel.kt @@ -0,0 +1,7 @@ +package com.shifthackz.aisdv1.domain.entity + +class SwarmUiModel( + val name: String, + val title: String, + val author: String, +) diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/interactor/settings/SetupConnectionInterActor.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/interactor/settings/SetupConnectionInterActor.kt index fc6dde12..4c42f81c 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/interactor/settings/SetupConnectionInterActor.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/interactor/settings/SetupConnectionInterActor.kt @@ -6,6 +6,7 @@ import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToHuggingFaceUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToLocalDiffusionUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToOpenAiUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToStabilityAiUseCase +import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToSwarmUiUseCase interface SetupConnectionInterActor { val connectToHorde: ConnectToHordeUseCase @@ -14,4 +15,5 @@ interface SetupConnectionInterActor { val connectToHuggingFace: ConnectToHuggingFaceUseCase val connectToOpenAi: ConnectToOpenAiUseCase val connectToStabilityAi: ConnectToStabilityAiUseCase + val connectToSwarmUi: ConnectToSwarmUiUseCase } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/interactor/settings/SetupConnectionInterActorImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/interactor/settings/SetupConnectionInterActorImpl.kt index 590ef434..05517631 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/interactor/settings/SetupConnectionInterActorImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/interactor/settings/SetupConnectionInterActorImpl.kt @@ -6,6 +6,7 @@ import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToHuggingFaceUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToLocalDiffusionUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToOpenAiUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToStabilityAiUseCase +import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToSwarmUiUseCase internal data class SetupConnectionInterActorImpl( override val connectToHorde: ConnectToHordeUseCase, @@ -14,4 +15,5 @@ internal data class SetupConnectionInterActorImpl( override val connectToHuggingFace: ConnectToHuggingFaceUseCase, override val connectToOpenAi: ConnectToOpenAiUseCase, override val connectToStabilityAi: ConnectToStabilityAiUseCase, + override val connectToSwarmUi: ConnectToSwarmUiUseCase, ) : SetupConnectionInterActor diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/preference/PreferenceManager.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/preference/PreferenceManager.kt index c37fd154..499abb42 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/preference/PreferenceManager.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/preference/PreferenceManager.kt @@ -5,7 +5,9 @@ import com.shifthackz.aisdv1.domain.entity.Settings import io.reactivex.rxjava3.core.Flowable interface PreferenceManager { - var serverUrl: String + var automatic1111ServerUrl: String + var swarmUiServerUrl: String + var swarmUiModel: String var demoMode: Boolean var monitorConnectivity: Boolean var autoSaveAiResults: Boolean diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/preference/SessionPreference.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/preference/SessionPreference.kt index 03bf43d1..79aadadb 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/preference/SessionPreference.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/preference/SessionPreference.kt @@ -1,5 +1,5 @@ package com.shifthackz.aisdv1.domain.preference interface SessionPreference { - var coinsPerDay: Int + var swarmUiSessionId: String } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/EmbeddingsRepository.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/EmbeddingsRepository.kt new file mode 100644 index 00000000..e641225e --- /dev/null +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/EmbeddingsRepository.kt @@ -0,0 +1,11 @@ +package com.shifthackz.aisdv1.domain.repository + +import com.shifthackz.aisdv1.domain.entity.Embedding +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single + +interface EmbeddingsRepository { + fun fetchEmbeddings(): Completable + fun fetchAndGetEmbeddings(): Single> + fun getEmbeddings(): Single> +} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/LorasRepository.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/LorasRepository.kt new file mode 100644 index 00000000..4beff109 --- /dev/null +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/LorasRepository.kt @@ -0,0 +1,11 @@ +package com.shifthackz.aisdv1.domain.repository + +import com.shifthackz.aisdv1.domain.entity.LoRA +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single + +interface LorasRepository { + fun fetchLoras(): Completable + fun fetchAndGetLoras(): Single> + fun getLoras(): Single> +} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/StableDiffusionEmbeddingsRepository.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/StableDiffusionEmbeddingsRepository.kt deleted file mode 100644 index 57118629..00000000 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/StableDiffusionEmbeddingsRepository.kt +++ /dev/null @@ -1,11 +0,0 @@ -package com.shifthackz.aisdv1.domain.repository - -import com.shifthackz.aisdv1.domain.entity.StableDiffusionEmbedding -import io.reactivex.rxjava3.core.Completable -import io.reactivex.rxjava3.core.Single - -interface StableDiffusionEmbeddingsRepository { - fun fetchEmbeddings(): Completable - fun fetchAndGetEmbeddings(): Single> - fun getEmbeddings(): Single> -} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/StableDiffusionLorasRepository.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/StableDiffusionLorasRepository.kt deleted file mode 100644 index 9a0d6f03..00000000 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/StableDiffusionLorasRepository.kt +++ /dev/null @@ -1,11 +0,0 @@ -package com.shifthackz.aisdv1.domain.repository - -import com.shifthackz.aisdv1.domain.entity.StableDiffusionLora -import io.reactivex.rxjava3.core.Completable -import io.reactivex.rxjava3.core.Single - -interface StableDiffusionLorasRepository { - fun fetchLoras(): Completable - fun fetchAndGetLoras(): Single> - fun getLoras(): Single> -} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/SwarmUiGenerationRepository.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/SwarmUiGenerationRepository.kt new file mode 100644 index 00000000..824ccb34 --- /dev/null +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/SwarmUiGenerationRepository.kt @@ -0,0 +1,14 @@ +package com.shifthackz.aisdv1.domain.repository + +import com.shifthackz.aisdv1.domain.entity.AiGenerationResult +import com.shifthackz.aisdv1.domain.entity.ImageToImagePayload +import com.shifthackz.aisdv1.domain.entity.TextToImagePayload +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single + +interface SwarmUiGenerationRepository { + fun checkApiAvailability(): Completable + fun checkApiAvailability(url: String): Completable + fun generateFromText(payload: TextToImagePayload): Single + fun generateFromImage(payload: ImageToImagePayload): Single +} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/SwarmUiModelsRepository.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/SwarmUiModelsRepository.kt new file mode 100644 index 00000000..5e97c9ab --- /dev/null +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/SwarmUiModelsRepository.kt @@ -0,0 +1,11 @@ +package com.shifthackz.aisdv1.domain.repository + +import com.shifthackz.aisdv1.domain.entity.SwarmUiModel +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single + +interface SwarmUiModelsRepository { + fun fetchModels(): Completable + fun fetchAndGetModels(): Single> + fun getModels(): Single> +} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/caching/DataPreLoaderUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/caching/DataPreLoaderUseCaseImpl.kt index c6eb7f77..84a5ffe8 100755 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/caching/DataPreLoaderUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/caching/DataPreLoaderUseCaseImpl.kt @@ -1,20 +1,19 @@ package com.shifthackz.aisdv1.domain.usecase.caching +import com.shifthackz.aisdv1.domain.repository.EmbeddingsRepository +import com.shifthackz.aisdv1.domain.repository.LorasRepository import com.shifthackz.aisdv1.domain.repository.ServerConfigurationRepository -import com.shifthackz.aisdv1.domain.repository.StableDiffusionEmbeddingsRepository import com.shifthackz.aisdv1.domain.repository.StableDiffusionHyperNetworksRepository -import com.shifthackz.aisdv1.domain.repository.StableDiffusionLorasRepository import com.shifthackz.aisdv1.domain.repository.StableDiffusionModelsRepository import com.shifthackz.aisdv1.domain.repository.StableDiffusionSamplersRepository -import io.reactivex.rxjava3.core.Completable internal class DataPreLoaderUseCaseImpl( private val serverConfigurationRepository: ServerConfigurationRepository, private val sdModelsRepository: StableDiffusionModelsRepository, private val sdSamplersRepository: StableDiffusionSamplersRepository, - private val sdLorasRepository: StableDiffusionLorasRepository, + private val sdLorasRepository: LorasRepository, private val sdHyperNetworksRepository: StableDiffusionHyperNetworksRepository, - private val sdEmbeddingsRepository: StableDiffusionEmbeddingsRepository, + private val sdEmbeddingsRepository: EmbeddingsRepository, ) : DataPreLoaderUseCase { override operator fun invoke() = serverConfigurationRepository diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/connectivity/TestSwarmUiConnectivityUseCase.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/connectivity/TestSwarmUiConnectivityUseCase.kt new file mode 100644 index 00000000..6a29bc0d --- /dev/null +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/connectivity/TestSwarmUiConnectivityUseCase.kt @@ -0,0 +1,7 @@ +package com.shifthackz.aisdv1.domain.usecase.connectivity + +import io.reactivex.rxjava3.core.Completable + +interface TestSwarmUiConnectivityUseCase { + operator fun invoke(url: String): Completable +} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/connectivity/TestSwarmUiConnectivityUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/connectivity/TestSwarmUiConnectivityUseCaseImpl.kt new file mode 100644 index 00000000..e79cc0e5 --- /dev/null +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/connectivity/TestSwarmUiConnectivityUseCaseImpl.kt @@ -0,0 +1,10 @@ +package com.shifthackz.aisdv1.domain.usecase.connectivity + +import com.shifthackz.aisdv1.domain.repository.SwarmUiGenerationRepository + +class TestSwarmUiConnectivityUseCaseImpl( + private val repository: SwarmUiGenerationRepository, +) : TestSwarmUiConnectivityUseCase { + + override fun invoke(url: String) = repository.checkApiAvailability(url) +} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/ImageToImageUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/ImageToImageUseCaseImpl.kt index 2336f99e..fcade672 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/ImageToImageUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/ImageToImageUseCaseImpl.kt @@ -7,11 +7,13 @@ import com.shifthackz.aisdv1.domain.repository.HordeGenerationRepository import com.shifthackz.aisdv1.domain.repository.HuggingFaceGenerationRepository import com.shifthackz.aisdv1.domain.repository.StabilityAiGenerationRepository import com.shifthackz.aisdv1.domain.repository.StableDiffusionGenerationRepository +import com.shifthackz.aisdv1.domain.repository.SwarmUiGenerationRepository import io.reactivex.rxjava3.core.Observable import io.reactivex.rxjava3.core.Single internal class ImageToImageUseCaseImpl( private val stableDiffusionGenerationRepository: StableDiffusionGenerationRepository, + private val swarmUiGenerationRepository: SwarmUiGenerationRepository, private val hordeGenerationRepository: HordeGenerationRepository, private val huggingFaceGenerationRepository: HuggingFaceGenerationRepository, private val stabilityAiGenerationRepository: StabilityAiGenerationRepository, @@ -25,6 +27,7 @@ internal class ImageToImageUseCaseImpl( private fun generate(payload: ImageToImagePayload) = when (preferenceManager.source) { ServerSource.AUTOMATIC1111 -> stableDiffusionGenerationRepository.generateFromImage(payload) + ServerSource.SWARM_UI -> swarmUiGenerationRepository.generateFromImage(payload) ServerSource.HORDE -> hordeGenerationRepository.generateFromImage(payload) ServerSource.HUGGING_FACE -> huggingFaceGenerationRepository.generateFromImage(payload) ServerSource.STABILITY_AI -> stabilityAiGenerationRepository.generateFromImage(payload) diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImpl.kt index 0d6ae856..4874171c 100755 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImpl.kt @@ -9,6 +9,7 @@ import com.shifthackz.aisdv1.domain.repository.LocalDiffusionGenerationRepositor import com.shifthackz.aisdv1.domain.repository.OpenAiGenerationRepository import com.shifthackz.aisdv1.domain.repository.StabilityAiGenerationRepository import com.shifthackz.aisdv1.domain.repository.StableDiffusionGenerationRepository +import com.shifthackz.aisdv1.domain.repository.SwarmUiGenerationRepository import io.reactivex.rxjava3.core.Observable internal class TextToImageUseCaseImpl( @@ -17,6 +18,7 @@ internal class TextToImageUseCaseImpl( private val huggingFaceGenerationRepository: HuggingFaceGenerationRepository, private val openAiGenerationRepository: OpenAiGenerationRepository, private val stabilityAiGenerationRepository: StabilityAiGenerationRepository, + private val swarmUiGenerationRepository: SwarmUiGenerationRepository, private val localDiffusionGenerationRepository: LocalDiffusionGenerationRepository, private val preferenceManager: PreferenceManager, ) : TextToImageUseCase { @@ -33,5 +35,6 @@ internal class TextToImageUseCaseImpl( ServerSource.AUTOMATIC1111 -> stableDiffusionGenerationRepository.generateFromText(payload) ServerSource.OPEN_AI -> openAiGenerationRepository.generateFromText(payload) ServerSource.STABILITY_AI -> stabilityAiGenerationRepository.generateFromText(payload) + ServerSource.SWARM_UI -> swarmUiGenerationRepository.generateFromText(payload) } } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/sdembedding/FetchAndGetEmbeddingsUseCase.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/sdembedding/FetchAndGetEmbeddingsUseCase.kt index 86f148ad..ca731928 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/sdembedding/FetchAndGetEmbeddingsUseCase.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/sdembedding/FetchAndGetEmbeddingsUseCase.kt @@ -1,8 +1,8 @@ package com.shifthackz.aisdv1.domain.usecase.sdembedding -import com.shifthackz.aisdv1.domain.entity.StableDiffusionEmbedding +import com.shifthackz.aisdv1.domain.entity.Embedding import io.reactivex.rxjava3.core.Single interface FetchAndGetEmbeddingsUseCase { - operator fun invoke(): Single> + operator fun invoke(): Single> } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/sdembedding/FetchAndGetEmbeddingsUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/sdembedding/FetchAndGetEmbeddingsUseCaseImpl.kt index 1f6d5c22..257e37e0 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/sdembedding/FetchAndGetEmbeddingsUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/sdembedding/FetchAndGetEmbeddingsUseCaseImpl.kt @@ -1,10 +1,12 @@ package com.shifthackz.aisdv1.domain.usecase.sdembedding -import com.shifthackz.aisdv1.domain.repository.StableDiffusionEmbeddingsRepository +import com.shifthackz.aisdv1.domain.entity.Embedding +import com.shifthackz.aisdv1.domain.repository.EmbeddingsRepository +import io.reactivex.rxjava3.core.Single internal class FetchAndGetEmbeddingsUseCaseImpl( - private val repository: StableDiffusionEmbeddingsRepository, + private val repository: EmbeddingsRepository, ) : FetchAndGetEmbeddingsUseCase { - override fun invoke() = repository.fetchAndGetEmbeddings() + override fun invoke(): Single> = repository.fetchAndGetEmbeddings() } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/sdlora/FetchAndGetLorasUseCase.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/sdlora/FetchAndGetLorasUseCase.kt index 820df950..65936385 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/sdlora/FetchAndGetLorasUseCase.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/sdlora/FetchAndGetLorasUseCase.kt @@ -1,8 +1,8 @@ package com.shifthackz.aisdv1.domain.usecase.sdlora -import com.shifthackz.aisdv1.domain.entity.StableDiffusionLora +import com.shifthackz.aisdv1.domain.entity.LoRA import io.reactivex.rxjava3.core.Single interface FetchAndGetLorasUseCase { - operator fun invoke(): Single> + operator fun invoke(): Single> } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/sdlora/FetchAndGetLorasUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/sdlora/FetchAndGetLorasUseCaseImpl.kt index a6c6d3eb..fd9bc42f 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/sdlora/FetchAndGetLorasUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/sdlora/FetchAndGetLorasUseCaseImpl.kt @@ -1,10 +1,10 @@ package com.shifthackz.aisdv1.domain.usecase.sdlora -import com.shifthackz.aisdv1.domain.repository.StableDiffusionLorasRepository +import com.shifthackz.aisdv1.domain.repository.LorasRepository internal class FetchAndGetLorasUseCaseImpl( - private val stableDiffusionLorasRepository: StableDiffusionLorasRepository, + private val lorasRepository: LorasRepository, ) : FetchAndGetLorasUseCase { - override fun invoke() = stableDiffusionLorasRepository.getLoras() + override fun invoke() = lorasRepository.fetchAndGetLoras() } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToSwarmUiUseCase.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToSwarmUiUseCase.kt new file mode 100644 index 00000000..7d77f2a6 --- /dev/null +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToSwarmUiUseCase.kt @@ -0,0 +1,8 @@ +package com.shifthackz.aisdv1.domain.usecase.settings + +import com.shifthackz.aisdv1.domain.feature.auth.AuthorizationCredentials +import io.reactivex.rxjava3.core.Single + +interface ConnectToSwarmUiUseCase { + operator fun invoke(url: String, credentials: AuthorizationCredentials): Single> +} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToSwarmUiUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToSwarmUiUseCaseImpl.kt new file mode 100644 index 00000000..78d771c9 --- /dev/null +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToSwarmUiUseCaseImpl.kt @@ -0,0 +1,43 @@ +package com.shifthackz.aisdv1.domain.usecase.settings + +import com.shifthackz.aisdv1.domain.entity.Configuration +import com.shifthackz.aisdv1.domain.entity.ServerSource +import com.shifthackz.aisdv1.domain.feature.auth.AuthorizationCredentials +import com.shifthackz.aisdv1.domain.usecase.connectivity.TestSwarmUiConnectivityUseCase +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single +import java.util.concurrent.TimeUnit + +internal class ConnectToSwarmUiUseCaseImpl( + private val getConfigurationUseCase: GetConfigurationUseCase, + private val setServerConfigurationUseCase : SetServerConfigurationUseCase, + private val testSwarmUiConnectivityUseCase: TestSwarmUiConnectivityUseCase, +) : ConnectToSwarmUiUseCase { + + override fun invoke( + url: String, + credentials: AuthorizationCredentials, + ): Single> { + var configuration: Configuration? = null + return getConfigurationUseCase() + .map { originalConfiguration -> + configuration = originalConfiguration + originalConfiguration.copy( + source = ServerSource.SWARM_UI, + swarmUiUrl = url, + authCredentials = credentials, + ) + } + .flatMapCompletable(setServerConfigurationUseCase::invoke) + .delay(3L, TimeUnit.SECONDS) + .andThen(testSwarmUiConnectivityUseCase(url)) + .andThen(Single.just(Result.success(Unit))) + .timeout(30L, TimeUnit.SECONDS) + .onErrorResumeNext { t -> + val chain = configuration?.let(setServerConfigurationUseCase::invoke) + ?: Completable.complete() + + chain.andThen(Single.just(Result.failure(t))) + } + } +} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImpl.kt index 67f79ad3..a46f4bd7 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImpl.kt @@ -12,7 +12,9 @@ internal class GetConfigurationUseCaseImpl( override fun invoke(): Single = Single.just( Configuration( - serverUrl = preferenceManager.serverUrl, + serverUrl = preferenceManager.automatic1111ServerUrl, + swarmUiUrl = preferenceManager.swarmUiServerUrl, + swarmUiModel = preferenceManager.swarmUiModel, demoMode = preferenceManager.demoMode, source = preferenceManager.source, hordeApiKey = preferenceManager.hordeApiKey, diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImpl.kt index 01ad9c66..83ceab34 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImpl.kt @@ -14,7 +14,9 @@ internal class SetServerConfigurationUseCaseImpl( Completable.fromAction { authorizationStore.storeAuthorizationCredentials(configuration.authCredentials) preferenceManager.source = configuration.source - preferenceManager.serverUrl = configuration.serverUrl + preferenceManager.automatic1111ServerUrl = configuration.serverUrl + preferenceManager.swarmUiServerUrl = configuration.swarmUiUrl + preferenceManager.swarmUiModel = configuration.swarmUiModel preferenceManager.demoMode = configuration.demoMode preferenceManager.hordeApiKey = configuration.hordeApiKey preferenceManager.openAiApiKey = configuration.openAiApiKey diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/splash/SplashNavigationUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/splash/SplashNavigationUseCaseImpl.kt index f50bfc26..2dbb2831 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/splash/SplashNavigationUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/splash/SplashNavigationUseCaseImpl.kt @@ -15,7 +15,7 @@ internal class SplashNavigationUseCaseImpl( Action.LAUNCH_SERVER_SETUP } - preferenceManager.serverUrl.isEmpty() + preferenceManager.automatic1111ServerUrl.isEmpty() && preferenceManager.source == ServerSource.AUTOMATIC1111 -> { Action.LAUNCH_SERVER_SETUP } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/stabilityai/FetchAndGetStabilityAiEnginesUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/stabilityai/FetchAndGetStabilityAiEnginesUseCaseImpl.kt index 93670eaa..5a17f6f6 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/stabilityai/FetchAndGetStabilityAiEnginesUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/stabilityai/FetchAndGetStabilityAiEnginesUseCaseImpl.kt @@ -1,5 +1,6 @@ package com.shifthackz.aisdv1.domain.usecase.stabilityai +import com.shifthackz.aisdv1.domain.entity.StabilityAiEngine import com.shifthackz.aisdv1.domain.preference.PreferenceManager import com.shifthackz.aisdv1.domain.repository.StabilityAiEnginesRepository import io.reactivex.rxjava3.core.Single @@ -12,7 +13,7 @@ internal class FetchAndGetStabilityAiEnginesUseCaseImpl( override fun invoke() = repository .fetchAndGet() .flatMap { engines -> - if (!engines.map { it.id }.contains(preferenceManager.stabilityAiEngineId)) { + if (!engines.map(StabilityAiEngine::id).contains(preferenceManager.stabilityAiEngineId)) { preferenceManager.stabilityAiEngineId = engines.firstOrNull()?.id ?: "" } Single.just(engines) diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/swarmmodel/FetchAndGetSwarmUiModelsUseCase.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/swarmmodel/FetchAndGetSwarmUiModelsUseCase.kt new file mode 100644 index 00000000..35b555fb --- /dev/null +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/swarmmodel/FetchAndGetSwarmUiModelsUseCase.kt @@ -0,0 +1,8 @@ +package com.shifthackz.aisdv1.domain.usecase.swarmmodel + +import com.shifthackz.aisdv1.domain.entity.SwarmUiModel +import io.reactivex.rxjava3.core.Single + +interface FetchAndGetSwarmUiModelsUseCase { + operator fun invoke(): Single> +} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/swarmmodel/FetchAndGetSwarmUiModelsUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/swarmmodel/FetchAndGetSwarmUiModelsUseCaseImpl.kt new file mode 100644 index 00000000..556067ab --- /dev/null +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/swarmmodel/FetchAndGetSwarmUiModelsUseCaseImpl.kt @@ -0,0 +1,21 @@ +package com.shifthackz.aisdv1.domain.usecase.swarmmodel + +import com.shifthackz.aisdv1.domain.entity.SwarmUiModel +import com.shifthackz.aisdv1.domain.preference.PreferenceManager +import com.shifthackz.aisdv1.domain.repository.SwarmUiModelsRepository +import io.reactivex.rxjava3.core.Single + +internal class FetchAndGetSwarmUiModelsUseCaseImpl( + private val preferenceManager: PreferenceManager, + private val repository: SwarmUiModelsRepository, +) : FetchAndGetSwarmUiModelsUseCase { + + override fun invoke(): Single> = repository + .fetchAndGetModels() + .map { models -> + if (!models.map(SwarmUiModel::name).contains(preferenceManager.swarmUiModel)) { + preferenceManager.swarmUiModel = models.firstOrNull()?.name ?: "" + } + models + } +} diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/ConfigurationMocks.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/ConfigurationMocks.kt index 84e9f9da..114e97b4 100644 --- a/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/ConfigurationMocks.kt +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/ConfigurationMocks.kt @@ -4,6 +4,8 @@ import com.shifthackz.aisdv1.domain.entity.Configuration val mockConfiguration = Configuration( serverUrl = "http://5598.is.my.favorite.com", + swarmUiUrl = "http://5598.is.my.favorite.com", + swarmUiModel = "5598", hordeApiKey = "5598", openAiApiKey = "5598", huggingFaceApiKey = "5598", diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/LoraMocks.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/LoraMocks.kt new file mode 100644 index 00000000..7e457a43 --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/LoraMocks.kt @@ -0,0 +1,16 @@ +package com.shifthackz.aisdv1.domain.mocks + +import com.shifthackz.aisdv1.domain.entity.LoRA + +val mockLoRAs = listOf( + LoRA( + name = "5598", + alias = "5598", + path = "/unknown", + ), + LoRA( + name = "151297", + alias = "151297", + path = "/unknown", + ), +) diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/StableDiffusionEmbeddingMocks.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/StableDiffusionEmbeddingMocks.kt index 8cf54d7e..3a44c632 100644 --- a/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/StableDiffusionEmbeddingMocks.kt +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/StableDiffusionEmbeddingMocks.kt @@ -1,9 +1,9 @@ package com.shifthackz.aisdv1.domain.mocks -import com.shifthackz.aisdv1.domain.entity.StableDiffusionEmbedding +import com.shifthackz.aisdv1.domain.entity.Embedding -val mockStableDiffusionEmbeddings = listOf( - StableDiffusionEmbedding("embedding_1"), - StableDiffusionEmbedding("embedding_2"), - StableDiffusionEmbedding("embedding_3"), +val mockEmbeddings = listOf( + Embedding("embedding_1"), + Embedding("embedding_2"), + Embedding("embedding_3"), ) diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/caching/DataPreLoaderUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/caching/DataPreLoaderUseCaseImplTest.kt index 372373b5..9a920900 100644 --- a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/caching/DataPreLoaderUseCaseImplTest.kt +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/caching/DataPreLoaderUseCaseImplTest.kt @@ -2,10 +2,10 @@ package com.shifthackz.aisdv1.domain.usecase.caching import com.nhaarman.mockitokotlin2.mock import com.nhaarman.mockitokotlin2.whenever +import com.shifthackz.aisdv1.domain.repository.EmbeddingsRepository +import com.shifthackz.aisdv1.domain.repository.LorasRepository import com.shifthackz.aisdv1.domain.repository.ServerConfigurationRepository -import com.shifthackz.aisdv1.domain.repository.StableDiffusionEmbeddingsRepository import com.shifthackz.aisdv1.domain.repository.StableDiffusionHyperNetworksRepository -import com.shifthackz.aisdv1.domain.repository.StableDiffusionLorasRepository import com.shifthackz.aisdv1.domain.repository.StableDiffusionModelsRepository import com.shifthackz.aisdv1.domain.repository.StableDiffusionSamplersRepository import io.reactivex.rxjava3.core.Completable @@ -16,17 +16,17 @@ class DataPreLoaderUseCaseImplTest { private val stubServerConfigurationRepository = mock() private val stubStableDiffusionModelsRepository = mock() private val stubStableDiffusionSamplersRepository = mock() - private val stubStableDiffusionLorasRepository = mock() + private val stubLorasRepository = mock() private val stubStableDiffusionHyperNetworksRepository = mock() - private val stubStableDiffusionEmbeddingsRepository = mock() + private val stubEmbeddingsRepository = mock() private val useCase = DataPreLoaderUseCaseImpl( serverConfigurationRepository = stubServerConfigurationRepository, sdModelsRepository = stubStableDiffusionModelsRepository, sdSamplersRepository = stubStableDiffusionSamplersRepository, - sdLorasRepository = stubStableDiffusionLorasRepository, + sdLorasRepository = stubLorasRepository, sdHyperNetworksRepository = stubStableDiffusionHyperNetworksRepository, - sdEmbeddingsRepository = stubStableDiffusionEmbeddingsRepository, + sdEmbeddingsRepository = stubEmbeddingsRepository, ) @Test @@ -40,13 +40,13 @@ class DataPreLoaderUseCaseImplTest { whenever(stubStableDiffusionSamplersRepository.fetchSamplers()) .thenReturn(Completable.complete()) - whenever(stubStableDiffusionLorasRepository.fetchLoras()) + whenever(stubLorasRepository.fetchLoras()) .thenReturn(Completable.complete()) whenever(stubStableDiffusionHyperNetworksRepository.fetchHyperNetworks()) .thenReturn(Completable.complete()) - whenever(stubStableDiffusionEmbeddingsRepository.fetchEmbeddings()) + whenever(stubEmbeddingsRepository.fetchEmbeddings()) .thenReturn(Completable.complete()) useCase() @@ -69,13 +69,13 @@ class DataPreLoaderUseCaseImplTest { whenever(stubStableDiffusionSamplersRepository.fetchSamplers()) .thenReturn(Completable.complete()) - whenever(stubStableDiffusionLorasRepository.fetchLoras()) + whenever(stubLorasRepository.fetchLoras()) .thenReturn(Completable.complete()) whenever(stubStableDiffusionHyperNetworksRepository.fetchHyperNetworks()) .thenReturn(Completable.complete()) - whenever(stubStableDiffusionEmbeddingsRepository.fetchEmbeddings()) + whenever(stubEmbeddingsRepository.fetchEmbeddings()) .thenReturn(Completable.complete()) useCase() @@ -98,13 +98,13 @@ class DataPreLoaderUseCaseImplTest { whenever(stubStableDiffusionSamplersRepository.fetchSamplers()) .thenReturn(Completable.complete()) - whenever(stubStableDiffusionLorasRepository.fetchLoras()) + whenever(stubLorasRepository.fetchLoras()) .thenReturn(Completable.complete()) whenever(stubStableDiffusionHyperNetworksRepository.fetchHyperNetworks()) .thenReturn(Completable.complete()) - whenever(stubStableDiffusionEmbeddingsRepository.fetchEmbeddings()) + whenever(stubEmbeddingsRepository.fetchEmbeddings()) .thenReturn(Completable.complete()) useCase() @@ -127,13 +127,13 @@ class DataPreLoaderUseCaseImplTest { whenever(stubStableDiffusionSamplersRepository.fetchSamplers()) .thenReturn(Completable.error(stubException)) - whenever(stubStableDiffusionLorasRepository.fetchLoras()) + whenever(stubLorasRepository.fetchLoras()) .thenReturn(Completable.complete()) whenever(stubStableDiffusionHyperNetworksRepository.fetchHyperNetworks()) .thenReturn(Completable.complete()) - whenever(stubStableDiffusionEmbeddingsRepository.fetchEmbeddings()) + whenever(stubEmbeddingsRepository.fetchEmbeddings()) .thenReturn(Completable.complete()) useCase() @@ -156,13 +156,13 @@ class DataPreLoaderUseCaseImplTest { whenever(stubStableDiffusionSamplersRepository.fetchSamplers()) .thenReturn(Completable.complete()) - whenever(stubStableDiffusionLorasRepository.fetchLoras()) + whenever(stubLorasRepository.fetchLoras()) .thenReturn(Completable.error(stubException)) whenever(stubStableDiffusionHyperNetworksRepository.fetchHyperNetworks()) .thenReturn(Completable.complete()) - whenever(stubStableDiffusionEmbeddingsRepository.fetchEmbeddings()) + whenever(stubEmbeddingsRepository.fetchEmbeddings()) .thenReturn(Completable.complete()) useCase() @@ -185,13 +185,13 @@ class DataPreLoaderUseCaseImplTest { whenever(stubStableDiffusionSamplersRepository.fetchSamplers()) .thenReturn(Completable.complete()) - whenever(stubStableDiffusionLorasRepository.fetchLoras()) + whenever(stubLorasRepository.fetchLoras()) .thenReturn(Completable.complete()) whenever(stubStableDiffusionHyperNetworksRepository.fetchHyperNetworks()) .thenReturn(Completable.error(stubException)) - whenever(stubStableDiffusionEmbeddingsRepository.fetchEmbeddings()) + whenever(stubEmbeddingsRepository.fetchEmbeddings()) .thenReturn(Completable.complete()) useCase() @@ -214,13 +214,13 @@ class DataPreLoaderUseCaseImplTest { whenever(stubStableDiffusionSamplersRepository.fetchSamplers()) .thenReturn(Completable.complete()) - whenever(stubStableDiffusionLorasRepository.fetchLoras()) + whenever(stubLorasRepository.fetchLoras()) .thenReturn(Completable.complete()) whenever(stubStableDiffusionHyperNetworksRepository.fetchHyperNetworks()) .thenReturn(Completable.complete()) - whenever(stubStableDiffusionEmbeddingsRepository.fetchEmbeddings()) + whenever(stubEmbeddingsRepository.fetchEmbeddings()) .thenReturn(Completable.error(stubException)) useCase() diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/ImageToImageUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/ImageToImageUseCaseImplTest.kt index 2a1544b3..9f2fac3b 100644 --- a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/ImageToImageUseCaseImplTest.kt +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/ImageToImageUseCaseImplTest.kt @@ -11,6 +11,7 @@ import com.shifthackz.aisdv1.domain.repository.HordeGenerationRepository import com.shifthackz.aisdv1.domain.repository.HuggingFaceGenerationRepository import com.shifthackz.aisdv1.domain.repository.StabilityAiGenerationRepository import com.shifthackz.aisdv1.domain.repository.StableDiffusionGenerationRepository +import com.shifthackz.aisdv1.domain.repository.SwarmUiGenerationRepository import io.reactivex.rxjava3.core.Single import org.junit.Test @@ -18,6 +19,7 @@ class ImageToImageUseCaseImplTest { private val stubException = Throwable("Unable to generate image.") private val stubStableDiffusionGenerationRepository = mock() + private val stubSwarmUiGenerationRepository = mock() private val stubHordeGenerationRepository = mock() private val stubHuggingFaceGenerationRepository = mock() private val stubStabilityAiGenerationRepository = mock() @@ -25,6 +27,7 @@ class ImageToImageUseCaseImplTest { private val useCase = ImageToImageUseCaseImpl( stableDiffusionGenerationRepository = stubStableDiffusionGenerationRepository, + swarmUiGenerationRepository = stubSwarmUiGenerationRepository, hordeGenerationRepository = stubHordeGenerationRepository, huggingFaceGenerationRepository = stubHuggingFaceGenerationRepository, stabilityAiGenerationRepository = stubStabilityAiGenerationRepository, diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImplTest.kt index 1283e409..b6d3bced 100644 --- a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImplTest.kt +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImplTest.kt @@ -13,6 +13,7 @@ import com.shifthackz.aisdv1.domain.repository.LocalDiffusionGenerationRepositor import com.shifthackz.aisdv1.domain.repository.OpenAiGenerationRepository import com.shifthackz.aisdv1.domain.repository.StabilityAiGenerationRepository import com.shifthackz.aisdv1.domain.repository.StableDiffusionGenerationRepository +import com.shifthackz.aisdv1.domain.repository.SwarmUiGenerationRepository import io.reactivex.rxjava3.core.Single import org.junit.Test @@ -24,6 +25,7 @@ class TextToImageUseCaseImplTest { private val stubHuggingFaceGenerationRepository = mock() private val stubOpenAiGenerationRepository = mock() private val stubStabilityAiGenerationRepository = mock() + private val stubSwarmUiGenerationRepository = mock() private val stubLocalDiffusionGenerationRepository = mock() private val stubPreferenceManager = mock() @@ -34,6 +36,7 @@ class TextToImageUseCaseImplTest { openAiGenerationRepository = stubOpenAiGenerationRepository, stabilityAiGenerationRepository = stubStabilityAiGenerationRepository, localDiffusionGenerationRepository = stubLocalDiffusionGenerationRepository, + swarmUiGenerationRepository = stubSwarmUiGenerationRepository, preferenceManager = stubPreferenceManager, ) diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/sdembedding/FetchAndGetEmbeddingsUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/sdembedding/FetchAndGetEmbeddingsUseCaseImplTest.kt index 28952d80..bc806e01 100644 --- a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/sdembedding/FetchAndGetEmbeddingsUseCaseImplTest.kt +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/sdembedding/FetchAndGetEmbeddingsUseCaseImplTest.kt @@ -3,26 +3,26 @@ package com.shifthackz.aisdv1.domain.usecase.sdembedding import com.nhaarman.mockitokotlin2.doReturn import com.nhaarman.mockitokotlin2.mock import com.nhaarman.mockitokotlin2.whenever -import com.shifthackz.aisdv1.domain.mocks.mockStableDiffusionEmbeddings -import com.shifthackz.aisdv1.domain.repository.StableDiffusionEmbeddingsRepository +import com.shifthackz.aisdv1.domain.mocks.mockEmbeddings +import com.shifthackz.aisdv1.domain.repository.EmbeddingsRepository import io.reactivex.rxjava3.core.Single import org.junit.Test class FetchAndGetEmbeddingsUseCaseImplTest { - private val stubRepository = mock() + private val stubRepository = mock() private val useCase = FetchAndGetEmbeddingsUseCaseImpl(stubRepository) @Test fun `given repository provided embeddings list, expected valid list value`() { whenever(stubRepository.fetchAndGetEmbeddings()) - .doReturn(Single.just(mockStableDiffusionEmbeddings)) + .doReturn(Single.just(mockEmbeddings)) useCase() .test() .assertNoErrors() - .assertValue(mockStableDiffusionEmbeddings) + .assertValue(mockEmbeddings) .await() .assertComplete() } diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/sdlora/FetchAndGetLorasUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/sdlora/FetchAndGetLorasUseCaseImplTest.kt new file mode 100644 index 00000000..643bf9f2 --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/sdlora/FetchAndGetLorasUseCaseImplTest.kt @@ -0,0 +1,55 @@ +package com.shifthackz.aisdv1.domain.usecase.sdlora + +import com.nhaarman.mockitokotlin2.mock +import com.nhaarman.mockitokotlin2.whenever +import com.shifthackz.aisdv1.domain.mocks.mockLoRAs +import com.shifthackz.aisdv1.domain.repository.LorasRepository +import io.reactivex.rxjava3.core.Single +import org.junit.Test + +class FetchAndGetLorasUseCaseImplTest { + + private val stubRepository = mock() + + private val useCase = FetchAndGetLorasUseCaseImpl(stubRepository) + + @Test + fun `given repository provided list of LoRAs, expected valid list value`() { + whenever(stubRepository.fetchAndGetLoras()) + .thenReturn(Single.just(mockLoRAs)) + + useCase() + .test() + .assertNoErrors() + .assertValue(mockLoRAs) + .await() + .assertComplete() + } + + @Test + fun `given repository provided empty list of LoRAs, expected empty list value`() { + whenever(stubRepository.fetchAndGetLoras()) + .thenReturn(Single.just(emptyList())) + + useCase() + .test() + .assertNoErrors() + .assertValue(emptyList()) + .await() + .assertComplete() + } + + @Test + fun `given repository thrown exception, expected error value`() { + val stubException = Throwable("Unknown error occurred.") + + whenever(stubRepository.fetchAndGetLoras()) + .thenReturn(Single.error(stubException)) + + useCase() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } +} diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImplTest.kt index c23282ba..9ce3e0af 100644 --- a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImplTest.kt +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImplTest.kt @@ -25,9 +25,17 @@ class GetConfigurationUseCaseImplTest { } returns AuthorizationCredentials.None every { - stubPreferenceManager::serverUrl.get() + stubPreferenceManager::automatic1111ServerUrl.get() } returns mockConfiguration.serverUrl + every { + stubPreferenceManager::swarmUiServerUrl.get() + } returns mockConfiguration.swarmUiUrl + + every { + stubPreferenceManager::swarmUiModel.get() + } returns mockConfiguration.swarmUiModel + every { stubPreferenceManager::demoMode.get() } returns mockConfiguration.demoMode diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImplTest.kt index fabc2fe4..e41c7ab7 100644 --- a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImplTest.kt +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImplTest.kt @@ -28,7 +28,15 @@ class SetServerConfigurationUseCaseImplTest { } returns Unit every { - stubPreferenceManager::serverUrl.set(any()) + stubPreferenceManager::automatic1111ServerUrl.set(any()) + } returns Unit + + every { + stubPreferenceManager::swarmUiModel.set(any()) + } returns Unit + + every { + stubPreferenceManager::swarmUiServerUrl.set(any()) } returns Unit every { diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/splash/SplashNavigationUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/splash/SplashNavigationUseCaseImplTest.kt index c2787eeb..cb76cb84 100644 --- a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/splash/SplashNavigationUseCaseImplTest.kt +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/splash/SplashNavigationUseCaseImplTest.kt @@ -28,7 +28,7 @@ class SplashNavigationUseCaseImplTest { whenever(stubPreferenceManager.forceSetupAfterUpdate) .thenReturn(false) - whenever(stubPreferenceManager.serverUrl) + whenever(stubPreferenceManager.automatic1111ServerUrl) .thenReturn("") whenever(stubPreferenceManager.source) @@ -45,7 +45,7 @@ class SplashNavigationUseCaseImplTest { whenever(stubPreferenceManager.forceSetupAfterUpdate) .thenReturn(false) - whenever(stubPreferenceManager.serverUrl) + whenever(stubPreferenceManager.automatic1111ServerUrl) .thenReturn("http://192.168.0.1:7860") whenever(stubPreferenceManager.source) @@ -62,7 +62,7 @@ class SplashNavigationUseCaseImplTest { whenever(stubPreferenceManager.forceSetupAfterUpdate) .thenReturn(false) - whenever(stubPreferenceManager.serverUrl) + whenever(stubPreferenceManager.automatic1111ServerUrl) .thenReturn("") whenever(stubPreferenceManager.source) @@ -79,7 +79,7 @@ class SplashNavigationUseCaseImplTest { whenever(stubPreferenceManager.forceSetupAfterUpdate) .thenReturn(false) - whenever(stubPreferenceManager.serverUrl) + whenever(stubPreferenceManager.automatic1111ServerUrl) .thenReturn("http://192.168.0.1:7860") whenever(stubPreferenceManager.source) diff --git a/network/src/main/java/com/shifthackz/aisdv1/network/api/swarmui/SwarmUiApi.kt b/network/src/main/java/com/shifthackz/aisdv1/network/api/swarmui/SwarmUiApi.kt new file mode 100644 index 00000000..0f261edd --- /dev/null +++ b/network/src/main/java/com/shifthackz/aisdv1/network/api/swarmui/SwarmUiApi.kt @@ -0,0 +1,64 @@ +package com.shifthackz.aisdv1.network.api.swarmui + +import android.graphics.Bitmap +import com.shifthackz.aisdv1.network.request.SwarmUiGenerationRequest +import com.shifthackz.aisdv1.network.request.SwarmUiModelsRequest +import com.shifthackz.aisdv1.network.response.SwarmUiGenerationResponse +import com.shifthackz.aisdv1.network.response.SwarmUiModelsResponse +import com.shifthackz.aisdv1.network.response.SwarmUiSessionResponse +import io.reactivex.rxjava3.core.Single +import okhttp3.ResponseBody +import retrofit2.Response +import retrofit2.http.Body +import retrofit2.http.GET +import retrofit2.http.POST +import retrofit2.http.Streaming +import retrofit2.http.Url + +interface SwarmUiApi { + + fun getNewSession(url: String): Single + + fun generate( + @Url url: String, + @Body request: SwarmUiGenerationRequest, + ): Single + + fun fetchModels( + @Url url: String, + @Body request: SwarmUiModelsRequest, + ): Single + + fun downloadImage(url: String): Single + + interface RawApi { + + @POST + fun getNewSession( + @Url url: String, + @Body map: Map, + ): Single + + @POST + fun generate( + @Url url: String, + @Body request: SwarmUiGenerationRequest, + ): Single + + @POST + fun fetchModels( + @Url url: String, + @Body request: SwarmUiModelsRequest, + ): Single + + @Streaming + @GET + fun download(@Url url: String): Single> + } + + companion object { + const val PATH_SESSION = "API/GetNewSession" + const val PATH_GENERATE = "API/GenerateText2Image" + const val PATH_MODELS = "API/ListModels" + } +} diff --git a/network/src/main/java/com/shifthackz/aisdv1/network/api/swarmui/SwarmUiApiImpl.kt b/network/src/main/java/com/shifthackz/aisdv1/network/api/swarmui/SwarmUiApiImpl.kt new file mode 100644 index 00000000..84a94a64 --- /dev/null +++ b/network/src/main/java/com/shifthackz/aisdv1/network/api/swarmui/SwarmUiApiImpl.kt @@ -0,0 +1,54 @@ +package com.shifthackz.aisdv1.network.api.swarmui + +import android.graphics.Bitmap +import android.graphics.BitmapFactory +import com.shifthackz.aisdv1.network.exception.BadSessionException +import com.shifthackz.aisdv1.network.request.SwarmUiGenerationRequest +import com.shifthackz.aisdv1.network.request.SwarmUiModelsRequest +import com.shifthackz.aisdv1.network.response.SwarmUiGenerationResponse +import com.shifthackz.aisdv1.network.response.SwarmUiModelsResponse +import com.shifthackz.aisdv1.network.response.SwarmUiSessionResponse +import io.reactivex.rxjava3.core.Single +import retrofit2.HttpException + +internal class SwarmUiApiImpl( + private val rawApi: SwarmUiApi.RawApi, +) : SwarmUiApi { + + override fun getNewSession(url: String): Single = rawApi + .getNewSession(url, emptyMap()) + .mapError() + + override fun generate( + url: String, + request: SwarmUiGenerationRequest, + ): Single = rawApi + .generate(url, request) + .mapError() + + override fun fetchModels( + url: String, + request: SwarmUiModelsRequest + ): Single = rawApi + .fetchModels(url, request) + .mapError() + + override fun downloadImage(url: String): Single = rawApi + .download(url) + .mapError() + .flatMap { response -> + response.body() + ?.bytes() + ?.let { BitmapFactory.decodeByteArray(it, 0, it.size) } + ?.let { Single.just(it) } + ?: Single.error(Throwable("Body is null")) + } + + private fun Single.mapError(): Single = this.onErrorResumeNext { t -> + if (t is HttpException && t.code() == 401) { + Single.error(BadSessionException()) + } else { + Single.error(t) + } + } +} diff --git a/network/src/main/java/com/shifthackz/aisdv1/network/di/NetworkModule.kt b/network/src/main/java/com/shifthackz/aisdv1/network/di/NetworkModule.kt index 0995039c..bf8c4867 100755 --- a/network/src/main/java/com/shifthackz/aisdv1/network/di/NetworkModule.kt +++ b/network/src/main/java/com/shifthackz/aisdv1/network/di/NetworkModule.kt @@ -14,6 +14,8 @@ import com.shifthackz.aisdv1.network.api.sdai.DownloadableModelsApi import com.shifthackz.aisdv1.network.api.sdai.DownloadableModelsApiImpl import com.shifthackz.aisdv1.network.api.sdai.HuggingFaceModelsApi import com.shifthackz.aisdv1.network.api.stabilityai.StabilityAiApi +import com.shifthackz.aisdv1.network.api.swarmui.SwarmUiApi +import com.shifthackz.aisdv1.network.api.swarmui.SwarmUiApiImpl import com.shifthackz.aisdv1.network.authenticator.RestAuthenticator import com.shifthackz.aisdv1.network.connectivity.ConnectivityMonitor import com.shifthackz.aisdv1.network.error.StabilityAiErrorMapper @@ -105,6 +107,12 @@ val networkModule = module { .create(Automatic1111RestApi::class.java) } + single { + get() + .withBaseUrl(get().stableDiffusionAutomaticApiUrl) + .create(SwarmUiApi.RawApi::class.java) + } + single { get() .withBaseUrl(get().hordeApiUrl) @@ -156,6 +164,7 @@ val networkModule = module { singleOf(::ImageCdnRestApiImpl) bind ImageCdnRestApi::class singleOf(::DownloadableModelsApiImpl) bind DownloadableModelsApi::class singleOf(::HuggingFaceInferenceApiImpl) bind HuggingFaceInferenceApi::class + singleOf(::SwarmUiApiImpl) bind SwarmUiApi::class factory { params -> ConnectivityMonitor(params.get()) diff --git a/network/src/main/java/com/shifthackz/aisdv1/network/exception/BadSessionException.kt b/network/src/main/java/com/shifthackz/aisdv1/network/exception/BadSessionException.kt new file mode 100644 index 00000000..c0589290 --- /dev/null +++ b/network/src/main/java/com/shifthackz/aisdv1/network/exception/BadSessionException.kt @@ -0,0 +1,3 @@ +package com.shifthackz.aisdv1.network.exception + +class BadSessionException : Throwable() diff --git a/network/src/main/java/com/shifthackz/aisdv1/network/model/SwarmUiModelRaw.kt b/network/src/main/java/com/shifthackz/aisdv1/network/model/SwarmUiModelRaw.kt new file mode 100644 index 00000000..47f5fc53 --- /dev/null +++ b/network/src/main/java/com/shifthackz/aisdv1/network/model/SwarmUiModelRaw.kt @@ -0,0 +1,12 @@ +package com.shifthackz.aisdv1.network.model + +import com.google.gson.annotations.SerializedName + +data class SwarmUiModelRaw( + @SerializedName("name") + val name: String?, + @SerializedName("title") + val title: String?, + @SerializedName("author") + val author: String?, +) diff --git a/network/src/main/java/com/shifthackz/aisdv1/network/request/SwarmUiGenerationRequest.kt b/network/src/main/java/com/shifthackz/aisdv1/network/request/SwarmUiGenerationRequest.kt new file mode 100644 index 00000000..7bd6f84c --- /dev/null +++ b/network/src/main/java/com/shifthackz/aisdv1/network/request/SwarmUiGenerationRequest.kt @@ -0,0 +1,39 @@ +package com.shifthackz.aisdv1.network.request + +import com.google.gson.annotations.SerializedName + +data class SwarmUiGenerationRequest( + @SerializedName("session_id") + val sessionId: String, + @SerializedName("model") + val model: String, + @SerializedName("initimage") + val initImage: String?, + @SerializedName("initimagecreativity") + val initImageCreativity: String?, + @SerializedName("images") + val images: Int, + @SerializedName("prompt") + val prompt: String, + @SerializedName("negativeprompt") + val negativePrompt: String, + @SerializedName("width") + val width: Int, + @SerializedName("height") + val height: Int, + @SerializedName("seed") + val seed: String?, + @SerializedName("variationseed") + val variationSeed: String?, + @SerializedName("variationseedstrength") + val variationSeedStrength: String?, + @SerializedName("cfgscale") + val cfgScale: Float?, + @SerializedName("steps") + val steps: Int, +// @SerializedName("initimageresettonorm") +// val initimageresettonorm: String = "0", +// @SerializedName("initimagerecompositemask") +// val initimagerecompositemask: String = "0", + +) diff --git a/network/src/main/java/com/shifthackz/aisdv1/network/request/SwarmUiModelsRequest.kt b/network/src/main/java/com/shifthackz/aisdv1/network/request/SwarmUiModelsRequest.kt new file mode 100644 index 00000000..35461f67 --- /dev/null +++ b/network/src/main/java/com/shifthackz/aisdv1/network/request/SwarmUiModelsRequest.kt @@ -0,0 +1,14 @@ +package com.shifthackz.aisdv1.network.request + +import com.google.gson.annotations.SerializedName + +data class SwarmUiModelsRequest( + @SerializedName("session_id") + val sessionId: String, + @SerializedName("subtype") + val subType: String, + @SerializedName("path") + val path: String, + @SerializedName("depth") + val depth: Int, +) diff --git a/network/src/main/java/com/shifthackz/aisdv1/network/response/SwarmUiGenerationResponse.kt b/network/src/main/java/com/shifthackz/aisdv1/network/response/SwarmUiGenerationResponse.kt new file mode 100644 index 00000000..ae8cdd12 --- /dev/null +++ b/network/src/main/java/com/shifthackz/aisdv1/network/response/SwarmUiGenerationResponse.kt @@ -0,0 +1,8 @@ +package com.shifthackz.aisdv1.network.response + +import com.google.gson.annotations.SerializedName + +data class SwarmUiGenerationResponse( + @SerializedName("images") + val images: List?, +) diff --git a/network/src/main/java/com/shifthackz/aisdv1/network/response/SwarmUiModelsResponse.kt b/network/src/main/java/com/shifthackz/aisdv1/network/response/SwarmUiModelsResponse.kt new file mode 100644 index 00000000..4913c5b9 --- /dev/null +++ b/network/src/main/java/com/shifthackz/aisdv1/network/response/SwarmUiModelsResponse.kt @@ -0,0 +1,9 @@ +package com.shifthackz.aisdv1.network.response + +import com.google.gson.annotations.SerializedName +import com.shifthackz.aisdv1.network.model.SwarmUiModelRaw + +data class SwarmUiModelsResponse( + @SerializedName("files") + val files: List?, +) diff --git a/network/src/main/java/com/shifthackz/aisdv1/network/response/SwarmUiSessionResponse.kt b/network/src/main/java/com/shifthackz/aisdv1/network/response/SwarmUiSessionResponse.kt new file mode 100644 index 00000000..8ec5052b --- /dev/null +++ b/network/src/main/java/com/shifthackz/aisdv1/network/response/SwarmUiSessionResponse.kt @@ -0,0 +1,8 @@ +package com.shifthackz.aisdv1.network.response + +import com.google.gson.annotations.SerializedName + +data class SwarmUiSessionResponse( + @SerializedName("session_id") + val sessionId: String?, +) diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/embedding/EmbeddingScreen.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/embedding/EmbeddingScreen.kt index 16b72387..57746dcc 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/embedding/EmbeddingScreen.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/embedding/EmbeddingScreen.kt @@ -47,10 +47,12 @@ import androidx.compose.ui.window.Dialog import androidx.compose.ui.window.DialogProperties import com.shifthackz.aisdv1.core.extensions.shimmer import com.shifthackz.aisdv1.core.ui.MviComponent +import com.shifthackz.aisdv1.domain.entity.ServerSource import com.shifthackz.aisdv1.presentation.R import com.shifthackz.aisdv1.presentation.modal.extras.ExtrasEffect import com.shifthackz.aisdv1.presentation.model.ErrorState import com.shifthackz.aisdv1.presentation.widget.error.ErrorComposable +import com.shifthackz.aisdv1.presentation.widget.source.getName import com.shifthackz.aisdv1.presentation.widget.toolbar.ModalDialogToolbar import org.koin.androidx.compose.koinViewModel @@ -209,7 +211,7 @@ private fun ScreenContent( } else { if (state.embeddings.isEmpty()) { item(key = "empty_state") { - EmbeddingEmptyState() + EmbeddingEmptyState(state.source) } } else { items( @@ -234,7 +236,7 @@ private fun ScreenContent( } @Composable -private fun EmbeddingEmptyState() { +private fun EmbeddingEmptyState(source: ServerSource) { Column( verticalArrangement = Arrangement.Center, ) { @@ -243,12 +245,21 @@ private fun EmbeddingEmptyState() { text = stringResource(id = R.string.extras_empty_title), fontSize = 20.sp, ) + val path = when (source) { + ServerSource.AUTOMATIC1111 -> "./embeddings" + ServerSource.SWARM_UI -> "./Models/Embeddings" + else -> "" + } Text( modifier = Modifier .padding(top = 16.dp) .padding(horizontal = 16.dp) .align(Alignment.CenterHorizontally), - text = stringResource(id = R.string.extras_empty_sub_title_embedding), + text = stringResource( + id = R.string.extras_empty_sub_title_embedding, + source.getName(), + path, + ), textAlign = TextAlign.Center, ) } diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/embedding/EmbeddingState.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/embedding/EmbeddingState.kt index 93338f60..9898c607 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/embedding/EmbeddingState.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/embedding/EmbeddingState.kt @@ -1,12 +1,14 @@ package com.shifthackz.aisdv1.presentation.modal.embedding import androidx.compose.runtime.Immutable +import com.shifthackz.aisdv1.domain.entity.ServerSource import com.shifthackz.aisdv1.presentation.model.ErrorState import com.shifthackz.android.core.mvi.MviState @Immutable data class EmbeddingState( val loading: Boolean = true, + val source: ServerSource = ServerSource.AUTOMATIC1111, val error: ErrorState = ErrorState.None, val prompt: String = "", val negativePrompt: String = "", diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/embedding/EmbeddingViewModel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/embedding/EmbeddingViewModel.kt index 2b284aee..70f75d31 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/embedding/EmbeddingViewModel.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/embedding/EmbeddingViewModel.kt @@ -1,10 +1,10 @@ package com.shifthackz.aisdv1.presentation.modal.embedding -import com.shifthackz.aisdv1.core.common.log.debugLog import com.shifthackz.aisdv1.core.common.log.errorLog import com.shifthackz.aisdv1.core.common.schedulers.SchedulersProvider import com.shifthackz.aisdv1.core.common.schedulers.subscribeOnMainThread import com.shifthackz.aisdv1.core.viewmodel.MviRxViewModel +import com.shifthackz.aisdv1.domain.preference.PreferenceManager import com.shifthackz.aisdv1.domain.usecase.sdembedding.FetchAndGetEmbeddingsUseCase import com.shifthackz.aisdv1.presentation.modal.extras.ExtrasEffect import com.shifthackz.aisdv1.presentation.model.ErrorState @@ -13,11 +13,18 @@ import io.reactivex.rxjava3.kotlin.subscribeBy class EmbeddingViewModel( private val fetchAndGetEmbeddingsUseCase: FetchAndGetEmbeddingsUseCase, + private val preferenceManager: PreferenceManager, private val schedulersProvider: SchedulersProvider, ) : MviRxViewModel() { override val initialState = EmbeddingState() + init { + updateState { + it.copy(source = preferenceManager.source) + } + } + override fun processIntent(intent: EmbeddingIntent) { when (intent) { EmbeddingIntent.ApplyNewPrompts -> emitEffect( @@ -53,7 +60,14 @@ class EmbeddingViewModel( } fun updateData(prompt: String, negativePrompt: String) = !fetchAndGetEmbeddingsUseCase() - .doOnSubscribe { updateState { it.copy(loading = true) } } + .doOnSubscribe { + updateState { state -> + state.copy( + loading = true, + source = preferenceManager.source, + ) + } + } .subscribeOnMainThread(schedulersProvider) .subscribeBy( onError = { t -> @@ -61,10 +75,10 @@ class EmbeddingViewModel( updateState { it.copy(loading = false, error = ErrorState.Generic) } }, onSuccess = { embeddings -> - debugLog(embeddings) updateState { state -> state.copy( loading = false, + source = preferenceManager.source, error = ErrorState.None, prompt = prompt, negativePrompt = negativePrompt, diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/extras/ExtrasScreen.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/extras/ExtrasScreen.kt index 5ae52cee..efee8a6d 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/extras/ExtrasScreen.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/extras/ExtrasScreen.kt @@ -44,10 +44,12 @@ import androidx.compose.ui.window.Dialog import androidx.compose.ui.window.DialogProperties import com.shifthackz.aisdv1.core.extensions.shimmer import com.shifthackz.aisdv1.core.ui.MviComponent +import com.shifthackz.aisdv1.domain.entity.ServerSource import com.shifthackz.aisdv1.presentation.R import com.shifthackz.aisdv1.presentation.model.ErrorState import com.shifthackz.aisdv1.presentation.model.ExtraType import com.shifthackz.aisdv1.presentation.widget.error.ErrorComposable +import com.shifthackz.aisdv1.presentation.widget.source.getName import com.shifthackz.aisdv1.presentation.widget.toolbar.ModalDialogToolbar import org.koin.androidx.compose.koinViewModel @@ -164,7 +166,7 @@ private fun ScreenContent( } else { if (state.loras.isEmpty()) { item(key = "empty_state") { - ExtrasEmptyState(type = state.type) + ExtrasEmptyState(state.type, state.source) } } else { items( @@ -188,7 +190,7 @@ private fun ScreenContent( } @Composable -private fun ExtrasEmptyState(type: ExtraType) { +private fun ExtrasEmptyState(type: ExtraType, source: ServerSource) { Column( verticalArrangement = Arrangement.Center, ) { @@ -197,16 +199,30 @@ private fun ExtrasEmptyState(type: ExtraType) { text = stringResource(id = R.string.extras_empty_title), fontSize = 20.sp, ) + val path = when (type) { + ExtraType.Lora -> when (source) { + ServerSource.AUTOMATIC1111 -> "../models/Lora" + ServerSource.SWARM_UI -> "../Models/Lora" + else -> "" + } + ExtraType.HyperNet -> when (source) { + ServerSource.AUTOMATIC1111 -> "../models/hypernetworks" + ServerSource.SWARM_UI -> "" + else -> "" + } + } Text( modifier = Modifier .padding(top = 16.dp) .padding(horizontal = 16.dp) .align(Alignment.CenterHorizontally), text = stringResource( - id = when (type) { + when (type) { ExtraType.Lora -> R.string.extras_empty_sub_title_lora ExtraType.HyperNet -> R.string.extras_empty_sub_title_hypernet - } + }, + source.getName(), + path, ), textAlign = TextAlign.Center, ) diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/extras/ExtrasState.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/extras/ExtrasState.kt index e75460ac..5860bfe0 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/extras/ExtrasState.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/extras/ExtrasState.kt @@ -1,6 +1,7 @@ package com.shifthackz.aisdv1.presentation.modal.extras import androidx.compose.runtime.Immutable +import com.shifthackz.aisdv1.domain.entity.ServerSource import com.shifthackz.aisdv1.presentation.model.ErrorState import com.shifthackz.aisdv1.presentation.model.ExtraType import com.shifthackz.android.core.mvi.MviState @@ -8,6 +9,7 @@ import com.shifthackz.android.core.mvi.MviState @Immutable data class ExtrasState( val loading: Boolean = true, + val source: ServerSource = ServerSource.AUTOMATIC1111, val error: ErrorState = ErrorState.None, val prompt: String = "", val negativePrompt: String = "", diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/extras/ExtrasViewModel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/extras/ExtrasViewModel.kt index d5bcd1d8..52f6e3a5 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/extras/ExtrasViewModel.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/extras/ExtrasViewModel.kt @@ -5,8 +5,9 @@ import com.shifthackz.aisdv1.core.common.schedulers.SchedulersProvider import com.shifthackz.aisdv1.core.common.schedulers.subscribeOnMainThread import com.shifthackz.aisdv1.core.common.time.TimeProvider import com.shifthackz.aisdv1.core.viewmodel.MviRxViewModel +import com.shifthackz.aisdv1.domain.entity.LoRA import com.shifthackz.aisdv1.domain.entity.StableDiffusionHyperNetwork -import com.shifthackz.aisdv1.domain.entity.StableDiffusionLora +import com.shifthackz.aisdv1.domain.preference.PreferenceManager import com.shifthackz.aisdv1.domain.usecase.sdhypernet.FetchAndGetHyperNetworksUseCase import com.shifthackz.aisdv1.domain.usecase.sdlora.FetchAndGetLorasUseCase import com.shifthackz.aisdv1.presentation.model.ErrorState @@ -18,11 +19,18 @@ class ExtrasViewModel( private val fetchAndGetLorasUseCase: FetchAndGetLorasUseCase, private val fetchAndGetHyperNetworksUseCase: FetchAndGetHyperNetworksUseCase, private val schedulersProvider: SchedulersProvider, + private val preferenceManager: PreferenceManager, private val timeProvider: TimeProvider, ) : MviRxViewModel() { override val initialState = ExtrasState() + init { + updateState { + it.copy(source = preferenceManager.source) + } + } + override fun processIntent(intent: ExtrasIntent) { when (intent) { ExtrasIntent.ApplyPrompts -> emitEffect( @@ -53,7 +61,15 @@ class ExtrasViewModel( ExtraType.Lora -> fetchAndGetLorasUseCase() ExtraType.HyperNet -> fetchAndGetHyperNetworksUseCase() } - .doOnSubscribe { updateState { it.copy(loading = true, type = type) } } + .doOnSubscribe { + updateState { state -> + state.copy( + loading = true, + type = type, + source = preferenceManager.source, + ) + } + } .subscribeOnMainThread(schedulersProvider) .subscribeBy( onError = { t -> @@ -64,6 +80,7 @@ class ExtrasViewModel( updateState { state -> state.copy( loading = false, + source = preferenceManager.source, error = ErrorState.None, prompt = prompt, negativePrompt = negativePrompt, @@ -72,14 +89,14 @@ class ExtrasViewModel( val (isApplied, value) = ExtrasFormatter.isExtraWithValuePresentInPrompt( prompt = prompt, loraAlias = when (it) { - is StableDiffusionLora -> it.alias + is LoRA -> it.alias is StableDiffusionHyperNetwork -> it.name else -> "" }, type = type, ) when (it) { - is StableDiffusionLora -> ExtraItemUi( + is LoRA -> ExtraItemUi( type = type, key = "${it.name}_${type}_${timeProvider.nanoTime()}", name = it.name, diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/navigation/graph/MainNavGraph.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/navigation/graph/MainNavGraph.kt index 7ba8081e..37c57c91 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/navigation/graph/MainNavGraph.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/navigation/graph/MainNavGraph.kt @@ -31,7 +31,7 @@ fun NavGraphBuilder.mainNavGraph() { ComposeNavigator.Destination(provider[ComposeNavigator::class]) { entry -> val sourceKey = entry.arguments ?.getInt(Constants.PARAM_SOURCE) - ?: ServerSetupLaunchSource.SPLASH.key + ?: ServerSetupLaunchSource.SPLASH.ordinal ServerSetupScreen(launchSourceKey = sourceKey) }.apply { route = Constants.ROUTE_SERVER_SETUP_FULL @@ -99,4 +99,4 @@ private fun debugMenuTab() = NavItem( icon = NavItem.Icon.Vector( vector = Icons.Default.Deck, ), -) \ No newline at end of file +) diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/navigation/router/main/MainRouterImpl.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/navigation/router/main/MainRouterImpl.kt index 23399f95..4e9978b7 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/navigation/router/main/MainRouterImpl.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/navigation/router/main/MainRouterImpl.kt @@ -34,7 +34,7 @@ internal class MainRouterImpl( } override fun navigateToServerSetup(source: ServerSetupLaunchSource) { - effectSubject.onNext(NavigationEffect.Navigate.RouteBuilder("${Constants.ROUTE_SERVER_SETUP}/${source.key}") { + effectSubject.onNext(NavigationEffect.Navigate.RouteBuilder("${Constants.ROUTE_SERVER_SETUP}/${source.ordinal}") { if (source == ServerSetupLaunchSource.SPLASH) { popUpTo(Constants.ROUTE_SPLASH) { inclusive = true diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageScreen.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageScreen.kt index 566905a1..77e61f25 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageScreen.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageScreen.kt @@ -136,6 +136,7 @@ private fun ScreenContent( content = { paddingValues -> when (state.mode) { ServerSource.AUTOMATIC1111, + ServerSource.SWARM_UI, ServerSource.HORDE, ServerSource.STABILITY_AI, ServerSource.HUGGING_FACE -> { @@ -182,8 +183,7 @@ private fun ScreenContent( } } - ServerSource.OPEN_AI, - ServerSource.LOCAL -> { + else -> { Column( modifier = Modifier .padding(paddingValues) diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsScreen.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsScreen.kt index 9ad999a3..ec4b47ad 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsScreen.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsScreen.kt @@ -107,7 +107,6 @@ fun SettingsScreen() { } } - @Composable private fun ScreenContent( modifier: Modifier = Modifier, @@ -188,6 +187,7 @@ private fun ContentSettingsState( ServerSource.OPEN_AI -> R.string.srv_type_open_ai ServerSource.STABILITY_AI -> R.string.srv_type_stability_ai ServerSource.LOCAL -> R.string.srv_type_local_short + ServerSource.SWARM_UI -> R.string.srv_type_swarm_ui }.asUiText(), onClick = { processIntent(SettingsIntent.NavigateConfiguration) }, ) diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsState.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsState.kt index 306ca562..1239c701 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsState.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsState.kt @@ -39,7 +39,7 @@ data class SettingsState( get() = serverSource == ServerSource.AUTOMATIC1111 val showMonitorConnectionOption: Boolean - get() = serverSource == ServerSource.AUTOMATIC1111 + get() = serverSource == ServerSource.AUTOMATIC1111 || serverSource == ServerSource.SWARM_UI val showFormAdvancedOption: Boolean get() = serverSource != ServerSource.OPEN_AI diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupIntent.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupIntent.kt index a367ced4..80d51618 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupIntent.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupIntent.kt @@ -12,6 +12,8 @@ sealed interface ServerSetupIntent : MviIntent { data class UpdateServerUrl(val url: String) : ServerSetupIntent + data class UpdateSwarmUiUrl(val url: String) : ServerSetupIntent + data class UpdateAuthType(val type: ServerSetupState.AuthType) : ServerSetupIntent data class UpdateLogin(val login: String) : ServerSetupIntent @@ -56,7 +58,11 @@ sealed interface ServerSetupIntent : MviIntent { data object A1111Instructions : LaunchUrl() { override val url: String get() = linksProvider.setupInstructionsUrl + } + data object SwarmUiInstructions : LaunchUrl() { + override val url: String + get() = linksProvider.swarmUiInfoUrl } data object HordeInfo : LaunchUrl() { diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupState.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupState.kt index 58c46cc4..555a4d32 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupState.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupState.kt @@ -21,6 +21,7 @@ data class ServerSetupState( val allowedModes: List = ServerSource.entries, val screenModal: Modal = Modal.None, val serverUrl: String = "", + val swarmUiUrl: String = "", val hordeApiKey: String = "", val huggingFaceApiKey: String = "", val openAiApiKey: String = "", @@ -36,6 +37,7 @@ data class ServerSetupState( val localCustomModel: Boolean = false, val passwordVisible: Boolean = false, val serverUrlValidationError: UiText? = null, + val swarmUiUrlValidationError: UiText? = null, val loginValidationError: UiText? = null, val passwordValidationError: UiText? = null, val hordeApiKeyValidationError: UiText? = null, @@ -83,13 +85,12 @@ data class ServerSetupState( ) } -//ToDo refactor key to enum ordinal -enum class ServerSetupLaunchSource(val key: Int) { - SPLASH(0), - SETTINGS(1); +enum class ServerSetupLaunchSource { + SPLASH, + SETTINGS; companion object { - fun fromKey(key: Int) = entries.firstOrNull { it.key == key } ?: SPLASH + fun fromKey(key: Int) = entries.firstOrNull { it.ordinal == key } ?: SPLASH } } diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt index 9d74b8d5..59e5d6e0 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt @@ -50,6 +50,18 @@ class ServerSetupViewModel( showBackNavArrow = launchSource == ServerSetupLaunchSource.SETTINGS, ) + private val credentials: AuthorizationCredentials + get() = when (currentState.mode) { + ServerSource.AUTOMATIC1111 -> { + if (!currentState.demoMode) currentState.credentialsDomain() + else AuthorizationCredentials.None + } + + ServerSource.SWARM_UI -> currentState.credentialsDomain() + + else -> AuthorizationCredentials.None + } + private var downloadDisposable: Disposable? = null init { @@ -73,6 +85,7 @@ class ServerSetupViewModel( mode = configuration.source, demoMode = configuration.demoMode, serverUrl = configuration.serverUrl, + swarmUiUrl = configuration.swarmUiUrl, authType = configuration.authType, ) .withCredentials(configuration.authCredentials) @@ -181,6 +194,10 @@ class ServerSetupViewModel( it.copy(serverUrl = intent.url, serverUrlValidationError = null) } + is ServerSetupIntent.UpdateSwarmUiUrl -> updateState { + it.copy(swarmUiUrl = intent.url, swarmUiUrlValidationError = null) + } + is ServerSetupIntent.LaunchUrl -> { emitEffect(ServerSetupEffect.LaunchUrl(intent.url)) } @@ -219,6 +236,7 @@ class ServerSetupViewModel( ServerSource.HUGGING_FACE -> connectToHuggingFace() ServerSource.OPEN_AI -> connectToOpenAi() ServerSource.STABILITY_AI -> connectToStabilityAi() + ServerSource.SWARM_UI -> connectToSwarmUi() } .doOnSubscribe { setScreenModal(Modal.Communicating(canCancel = false)) } .subscribeOnMainThread(schedulersProvider) @@ -236,34 +254,11 @@ class ServerSetupViewModel( private fun validate(): Boolean = when (currentState.mode) { ServerSource.AUTOMATIC1111 -> { if (currentState.demoMode) true - else { - val serverUrlValidation = urlValidator(currentState.serverUrl) - var isValid = serverUrlValidation.isValid - updateState { state -> - var newState = state.copy( - serverUrlValidationError = serverUrlValidation.mapToUi() - ) - if (currentState.authType == ServerSetupState.AuthType.HTTP_BASIC) { - val loginValidation = stringValidator(currentState.login) - val passwordValidation = stringValidator(currentState.password) - newState = newState.copy( - loginValidationError = loginValidation.mapToUi(), - passwordValidationError = passwordValidation.mapToUi() - ) - isValid = isValid && loginValidation.isValid && passwordValidation.isValid - } - if (serverUrlValidation.validationError is UrlValidator.Error.Localhost - && newState.loginValidationError == null - && newState.passwordValidationError == null - ) { - newState = newState.copy(screenModal = Modal.ConnectLocalHost) - } - newState - } - isValid - } + else validateServerUrlAndCredentials(currentState.serverUrl) } + ServerSource.SWARM_UI -> validateServerUrlAndCredentials(currentState.swarmUiUrl) + ServerSource.HORDE -> { if (currentState.hordeDefaultApiKey) true else { @@ -304,37 +299,70 @@ class ServerSetupViewModel( } } + private fun validateServerUrlAndCredentials(url: String): Boolean { + val serverUrlValidation = urlValidator(url) + var isValid = serverUrlValidation.isValid + updateState { state -> + var newState = state.copy( + serverUrlValidationError = if (state.mode == ServerSource.AUTOMATIC1111) { + serverUrlValidation.mapToUi() + } else { + state.serverUrlValidationError + }, + swarmUiUrlValidationError = if (state.mode == ServerSource.SWARM_UI) { + serverUrlValidation.mapToUi() + } else { + state.swarmUiUrlValidationError + }, + ) + if (currentState.authType == ServerSetupState.AuthType.HTTP_BASIC) { + val loginValidation = stringValidator(currentState.login) + val passwordValidation = stringValidator(currentState.password) + newState = newState.copy( + loginValidationError = loginValidation.mapToUi(), + passwordValidationError = passwordValidation.mapToUi() + ) + isValid = isValid && loginValidation.isValid && passwordValidation.isValid + } + if (serverUrlValidation.validationError is UrlValidator.Error.Localhost + && newState.loginValidationError == null + && newState.passwordValidationError == null + ) { + newState = newState.copy(screenModal = Modal.ConnectLocalHost) + } + newState + } + return isValid + } + private fun connectToAutomaticInstance(): Single> { val demoMode = currentState.demoMode val connectUrl = if (demoMode) currentState.demoModeUrl else currentState.serverUrl - val credentials = when (currentState.mode) { - ServerSource.AUTOMATIC1111 -> { - if (!demoMode) currentState.credentialsDomain() - else AuthorizationCredentials.None - } - - else -> AuthorizationCredentials.None - } return setupConnectionInterActor.connectToA1111( - connectUrl, - demoMode, - credentials, + url = connectUrl, + isDemo = demoMode, + credentials = credentials, ) } + private fun connectToSwarmUi() = setupConnectionInterActor.connectToSwarmUi( + url = currentState.swarmUiUrl, + credentials = credentials, + ) + private fun connectToHuggingFace() = with(currentState) { setupConnectionInterActor.connectToHuggingFace( - huggingFaceApiKey, - huggingFaceModel, + apiKey = huggingFaceApiKey, + model = huggingFaceModel, ) } private fun connectToOpenAi() = setupConnectionInterActor.connectToOpenAi( - currentState.openAiApiKey, + apiKey = currentState.openAiApiKey, ) private fun connectToStabilityAi() = setupConnectionInterActor.connectToStabilityAi( - currentState.stabilityAiApiKey, + apiKey = currentState.stabilityAiApiKey, ) private fun connectToHorde(): Single> { diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/components/ConfigurationModeButton.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/components/ConfigurationModeButton.kt index 0bf7a40b..cd7624a6 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/components/ConfigurationModeButton.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/components/ConfigurationModeButton.kt @@ -34,6 +34,7 @@ import com.shifthackz.aisdv1.domain.entity.ServerSource import com.shifthackz.aisdv1.presentation.R import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupState import com.shifthackz.aisdv1.presentation.screen.setup.mappers.mapToUi +import com.shifthackz.aisdv1.presentation.widget.source.getName @Composable fun ConfigurationModeButton( @@ -65,7 +66,8 @@ fun ConfigurationModeButton( .size(42.dp) .padding(top = 8.dp, bottom = 8.dp), imageVector = when (mode) { - ServerSource.AUTOMATIC1111 -> Icons.Default.Computer + ServerSource.AUTOMATIC1111, + ServerSource.SWARM_UI -> Icons.Default.Computer ServerSource.HORDE, ServerSource.OPEN_AI, ServerSource.STABILITY_AI, @@ -79,25 +81,20 @@ fun ConfigurationModeButton( modifier = Modifier .align(Alignment.CenterVertically) .padding(top = 8.dp, bottom = 8.dp), - text = stringResource(id = when (mode) { - ServerSource.AUTOMATIC1111 -> R.string.srv_type_own - ServerSource.HORDE -> R.string.srv_type_horde - ServerSource.LOCAL -> R.string.srv_type_local - ServerSource.HUGGING_FACE -> R.string.srv_type_hugging_face - ServerSource.OPEN_AI -> R.string.srv_type_open_ai - ServerSource.STABILITY_AI -> R.string.srv_type_stability_ai - }), + text = mode.getName(), textAlign = TextAlign.Center, style = MaterialTheme.typography.bodyLarge, ) } val descriptionId = when (mode) { - ServerSource.AUTOMATIC1111 -> null + ServerSource.AUTOMATIC1111 -> R.string.hint_server_setup_sub_title ServerSource.HORDE -> R.string.hint_server_horde_sub_title ServerSource.HUGGING_FACE -> R.string.hint_hugging_face_sub_title ServerSource.OPEN_AI -> R.string.hint_open_ai_sub_title ServerSource.LOCAL -> R.string.hint_local_diffusion_sub_title ServerSource.STABILITY_AI -> R.string.hint_stability_ai_sub_title + ServerSource.SWARM_UI -> R.string.hint_swarm_ui_sub_title + else -> null } descriptionId?.let { resId -> Text( diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/AuthCredentialsForm.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/AuthCredentialsForm.kt new file mode 100644 index 00000000..9daf1976 --- /dev/null +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/AuthCredentialsForm.kt @@ -0,0 +1,99 @@ +package com.shifthackz.aisdv1.presentation.screen.setup.forms + +import androidx.compose.foundation.layout.ColumnScope +import androidx.compose.foundation.text.KeyboardOptions +import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.filled.Visibility +import androidx.compose.material.icons.filled.VisibilityOff +import androidx.compose.material3.Icon +import androidx.compose.material3.IconButton +import androidx.compose.material3.MaterialTheme +import androidx.compose.material3.Text +import androidx.compose.material3.TextField +import androidx.compose.runtime.Composable +import androidx.compose.ui.Modifier +import androidx.compose.ui.res.stringResource +import androidx.compose.ui.text.input.KeyboardType +import androidx.compose.ui.text.input.PasswordVisualTransformation +import androidx.compose.ui.text.input.VisualTransformation +import com.shifthackz.aisdv1.core.model.asString +import com.shifthackz.aisdv1.core.model.asUiText +import com.shifthackz.aisdv1.presentation.R +import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupIntent +import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupState +import com.shifthackz.aisdv1.presentation.widget.input.DropdownTextField + +@Composable +fun ColumnScope.AuthCredentialsForm( + state: ServerSetupState, + processIntent: (ServerSetupIntent) -> Unit, + modifier: Modifier, +) { + DropdownTextField( + modifier = modifier, + label = R.string.auth_title.asUiText(), + items = ServerSetupState.AuthType.entries, + value = state.authType, + onItemSelected = { + processIntent(ServerSetupIntent.UpdateAuthType(it)) + }, + displayDelegate = { type -> + when (type) { + ServerSetupState.AuthType.ANONYMOUS -> R.string.auth_anonymous + ServerSetupState.AuthType.HTTP_BASIC -> R.string.auth_http_basic + }.asUiText() + } + ) + when (state.authType) { + ServerSetupState.AuthType.HTTP_BASIC -> { + TextField( + modifier = modifier, + value = state.login, + onValueChange = { + processIntent(ServerSetupIntent.UpdateLogin(it)) + }, + label = { Text(stringResource(id = R.string.hint_login)) }, + isError = state.loginValidationError != null, + supportingText = state.loginValidationError?.let { + { Text(it.asString(), color = MaterialTheme.colorScheme.error) } + }, + maxLines = 1, + ) + TextField( + modifier = modifier, + value = state.password, + onValueChange = { + processIntent(ServerSetupIntent.UpdatePassword(it)) + }, + label = { Text(stringResource(id = R.string.hint_password)) }, + isError = state.passwordValidationError != null, + keyboardOptions = KeyboardOptions(keyboardType = KeyboardType.Password), + visualTransformation = if (state.passwordVisible) { + VisualTransformation.None + } else { + PasswordVisualTransformation() + }, + supportingText = state.passwordValidationError?.let { + { Text(it.asString(), color = MaterialTheme.colorScheme.error) } + }, + trailingIcon = { + val image = if (state.passwordVisible) Icons.Filled.Visibility + else Icons.Filled.VisibilityOff + val description = if (state.passwordVisible) "Hide password" else "Show password" + IconButton( + onClick = { + processIntent( + ServerSetupIntent.UpdatePasswordVisibility( + state.passwordVisible, + ), + ) + }, + content = { Icon(image, description) }, + ) + }, + maxLines = 1, + ) + } + else -> Unit + } +} diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/Automatic1111Form.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/Automatic1111Form.kt index 2614a70e..1a60852b 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/Automatic1111Form.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/Automatic1111Form.kt @@ -3,14 +3,9 @@ package com.shifthackz.aisdv1.presentation.screen.setup.forms import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.fillMaxWidth import androidx.compose.foundation.layout.padding -import androidx.compose.foundation.text.KeyboardOptions import androidx.compose.material.icons.Icons import androidx.compose.material.icons.automirrored.filled.Help import androidx.compose.material.icons.filled.DeveloperMode -import androidx.compose.material.icons.filled.Visibility -import androidx.compose.material.icons.filled.VisibilityOff -import androidx.compose.material3.Icon -import androidx.compose.material3.IconButton import androidx.compose.material3.MaterialTheme import androidx.compose.material3.Switch import androidx.compose.material3.Text @@ -19,9 +14,6 @@ import androidx.compose.runtime.Composable import androidx.compose.ui.Modifier import androidx.compose.ui.res.stringResource import androidx.compose.ui.text.font.FontWeight -import androidx.compose.ui.text.input.KeyboardType -import androidx.compose.ui.text.input.PasswordVisualTransformation -import androidx.compose.ui.text.input.VisualTransformation import androidx.compose.ui.text.style.TextAlign import androidx.compose.ui.unit.dp import com.shifthackz.aisdv1.core.model.asString @@ -29,7 +21,6 @@ import com.shifthackz.aisdv1.core.model.asUiText import com.shifthackz.aisdv1.presentation.R import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupIntent import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupState -import com.shifthackz.aisdv1.presentation.widget.input.DropdownTextField import com.shifthackz.aisdv1.presentation.widget.item.SettingsItem @Composable @@ -69,73 +60,11 @@ fun Automatic1111Form( maxLines = 1, ) if (!state.demoMode) { - DropdownTextField( + AuthCredentialsForm( + state = state, + processIntent = processIntent, modifier = fieldModifier, - label = R.string.auth_title.asUiText(), - items = ServerSetupState.AuthType.entries, - value = state.authType, - onItemSelected = { - processIntent(ServerSetupIntent.UpdateAuthType(it)) - }, - displayDelegate = { type -> - when (type) { - ServerSetupState.AuthType.ANONYMOUS -> R.string.auth_anonymous - ServerSetupState.AuthType.HTTP_BASIC -> R.string.auth_http_basic - }.asUiText() - } ) - when (state.authType) { - ServerSetupState.AuthType.HTTP_BASIC -> { - TextField( - modifier = fieldModifier, - value = state.login, - onValueChange = { - processIntent(ServerSetupIntent.UpdateLogin(it)) - }, - label = { Text(stringResource(id = R.string.hint_login)) }, - isError = state.loginValidationError != null, - supportingText = state.loginValidationError?.let { - { Text(it.asString(), color = MaterialTheme.colorScheme.error) } - }, - maxLines = 1, - ) - TextField( - modifier = fieldModifier, - value = state.password, - onValueChange = { - processIntent(ServerSetupIntent.UpdatePassword(it)) - }, - label = { Text(stringResource(id = R.string.hint_password)) }, - isError = state.passwordValidationError != null, - keyboardOptions = KeyboardOptions(keyboardType = KeyboardType.Password), - visualTransformation = if (state.passwordVisible) { - VisualTransformation.None - } else { - PasswordVisualTransformation() - }, - supportingText = state.passwordValidationError?.let { - { Text(it.asString(), color = MaterialTheme.colorScheme.error) } - }, - trailingIcon = { - val image = if (state.passwordVisible) Icons.Filled.Visibility - else Icons.Filled.VisibilityOff - val description = if (state.passwordVisible) "Hide password" else "Show password" - IconButton( - onClick = { - processIntent( - ServerSetupIntent.UpdatePasswordVisibility( - state.passwordVisible, - ), - ) - }, - content = { Icon(image, description) }, - ) - }, - maxLines = 1, - ) - } - else -> Unit - } } SettingsItem( modifier = Modifier @@ -169,10 +98,11 @@ fun Automatic1111Form( ) Text( modifier = Modifier.padding(top = 8.dp, bottom = 16.dp), - text = stringResource( - if (state.demoMode) R.string.hint_demo_mode - else R.string.hint_valid_urls, - ), + text = if (state.demoMode) { + stringResource(R.string.hint_demo_mode) + } else { + stringResource(R.string.hint_valid_urls, "7860") + }, style = MaterialTheme.typography.bodyMedium, ) } diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/SwarmUiForm.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/SwarmUiForm.kt new file mode 100644 index 00000000..a70cf7b5 --- /dev/null +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/SwarmUiForm.kt @@ -0,0 +1,82 @@ +package com.shifthackz.aisdv1.presentation.screen.setup.forms + +import androidx.compose.foundation.layout.Column +import androidx.compose.foundation.layout.fillMaxWidth +import androidx.compose.foundation.layout.padding +import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.automirrored.filled.Help +import androidx.compose.material3.MaterialTheme +import androidx.compose.material3.Text +import androidx.compose.material3.TextField +import androidx.compose.runtime.Composable +import androidx.compose.ui.Modifier +import androidx.compose.ui.res.stringResource +import androidx.compose.ui.text.font.FontWeight +import androidx.compose.ui.text.style.TextAlign +import androidx.compose.ui.unit.dp +import com.shifthackz.aisdv1.core.model.asString +import com.shifthackz.aisdv1.core.model.asUiText +import com.shifthackz.aisdv1.presentation.R +import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupIntent +import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupState +import com.shifthackz.aisdv1.presentation.widget.item.SettingsItem + +@Composable +fun SwarmUiForm( + modifier: Modifier = Modifier, + state: ServerSetupState, + processIntent: (ServerSetupIntent) -> Unit, +) { + Column( + modifier = modifier + .padding(horizontal = 16.dp), + ) { + Text( + modifier = Modifier + .fillMaxWidth() + .padding(top = 32.dp, bottom = 8.dp), + text = stringResource(id = R.string.hint_swarm_ui_title), + style = MaterialTheme.typography.bodyLarge, + textAlign = TextAlign.Center, + fontWeight = FontWeight.Bold, + ) + val fieldModifier = Modifier + .fillMaxWidth() + .padding(top = 8.dp) + TextField( + modifier = fieldModifier, + value = state.swarmUiUrl, + onValueChange = { + processIntent(ServerSetupIntent.UpdateSwarmUiUrl(it)) + }, + label = { Text(stringResource(id = R.string.hint_server_url)) }, + isError = state.swarmUiUrlValidationError != null, + supportingText = state.swarmUiUrlValidationError + ?.let { { Text(it.asString(), color = MaterialTheme.colorScheme.error) } }, + maxLines = 1, + ) + AuthCredentialsForm( + state = state, + processIntent = processIntent, + modifier = fieldModifier, + ) + SettingsItem( + modifier = Modifier + .padding(top = 16.dp) + .fillMaxWidth(), + startIcon = Icons.AutoMirrored.Filled.Help, + text = R.string.settings_item_instructions.asUiText(), + onClick = { processIntent(ServerSetupIntent.LaunchUrl.SwarmUiInstructions) }, + ) + Text( + modifier = Modifier.padding(top = 8.dp), + text = stringResource(id = R.string.hint_args_swarm_ui_warning), + style = MaterialTheme.typography.bodyMedium, + ) + Text( + modifier = Modifier.padding(top = 8.dp, bottom = 16.dp), + text = stringResource(R.string.hint_valid_urls, "7801"), + style = MaterialTheme.typography.bodyMedium, + ) + } +} diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/mappers/FeatureTagMapper.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/mappers/FeatureTagMapper.kt index a817cb54..c2f043dc 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/mappers/FeatureTagMapper.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/mappers/FeatureTagMapper.kt @@ -11,6 +11,7 @@ fun FeatureTag.mapToUi(): String { id = when (this) { FeatureTag.Txt2Img -> R.string.home_tab_txt_to_img FeatureTag.Img2Img -> R.string.home_tab_img_to_img + FeatureTag.OwnServer -> R.string.hint_own_server FeatureTag.Lora -> R.string.title_lora FeatureTag.TextualInversion -> R.string.title_txt_inversion FeatureTag.HyperNetworks -> R.string.title_hyper_net diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/steps/ConfigurationStep.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/steps/ConfigurationStep.kt index f78f93d6..4d53ca89 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/steps/ConfigurationStep.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/steps/ConfigurationStep.kt @@ -12,6 +12,7 @@ import com.shifthackz.aisdv1.presentation.screen.setup.forms.HuggingFaceForm import com.shifthackz.aisdv1.presentation.screen.setup.forms.LocalDiffusionForm import com.shifthackz.aisdv1.presentation.screen.setup.forms.OpenAiForm import com.shifthackz.aisdv1.presentation.screen.setup.forms.StabilityAiForm +import com.shifthackz.aisdv1.presentation.screen.setup.forms.SwarmUiForm @Composable fun ConfigurationStep( @@ -52,6 +53,11 @@ fun ConfigurationStep( state = state, processIntent = processIntent, ) + + ServerSource.SWARM_UI -> SwarmUiForm( + state = state, + processIntent = processIntent, + ) } } } diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionComponent.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionComponent.kt index cb007e51..878640b9 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionComponent.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionComponent.kt @@ -27,6 +27,15 @@ fun EngineSelectionComponent( onItemSelected = { intentHandler(EngineSelectionIntent(it)) }, ) + ServerSource.SWARM_UI -> DropdownTextField( + label = R.string.hint_sd_model.asUiText(), + loading = state.loading, + modifier = modifier, + value = state.selectedSwarmModel, + items = state.swarmModels, + onItemSelected = { intentHandler(EngineSelectionIntent(it)) }, + ) + ServerSource.HUGGING_FACE -> DropdownTextField( label = R.string.hint_hugging_face_model.asUiText(), loading = state.loading, diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionState.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionState.kt index 7b510934..382a299d 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionState.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionState.kt @@ -11,6 +11,8 @@ data class EngineSelectionState( val mode: ServerSource = ServerSource.AUTOMATIC1111, val sdModels: List = emptyList(), val selectedSdModel: String = "", + val swarmModels: List = emptyList(), + val selectedSwarmModel: String = "", val hfModels: List = emptyList(), val selectedHfModel: String = "", val stEngines: List = emptyList(), diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModel.kt index 33582b3b..4b097c13 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModel.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModel.kt @@ -2,7 +2,7 @@ package com.shifthackz.aisdv1.presentation.widget.engine import com.shifthackz.aisdv1.core.common.extensions.EmptyLambda import com.shifthackz.aisdv1.core.common.log.errorLog -import com.shifthackz.aisdv1.core.common.model.Quintuple +import com.shifthackz.aisdv1.core.common.model.Hexagonal import com.shifthackz.aisdv1.core.common.schedulers.SchedulersProvider import com.shifthackz.aisdv1.core.common.schedulers.subscribeOnMainThread import com.shifthackz.aisdv1.core.viewmodel.MviRxViewModel @@ -16,6 +16,7 @@ import com.shifthackz.aisdv1.domain.usecase.sdmodel.GetStableDiffusionModelsUseC import com.shifthackz.aisdv1.domain.usecase.sdmodel.SelectStableDiffusionModelUseCase import com.shifthackz.aisdv1.domain.usecase.settings.GetConfigurationUseCase import com.shifthackz.aisdv1.domain.usecase.stabilityai.FetchAndGetStabilityAiEnginesUseCase +import com.shifthackz.aisdv1.domain.usecase.swarmmodel.FetchAndGetSwarmUiModelsUseCase import com.shifthackz.android.core.mvi.EmptyEffect import io.reactivex.rxjava3.core.Flowable import io.reactivex.rxjava3.kotlin.subscribeBy @@ -26,6 +27,7 @@ class EngineSelectionViewModel( private val getConfigurationUseCase: GetConfigurationUseCase, private val selectStableDiffusionModelUseCase: SelectStableDiffusionModelUseCase, private val getStableDiffusionModelsUseCase: GetStableDiffusionModelsUseCase, + fetchAndGetSwarmUiModelsUseCase: FetchAndGetSwarmUiModelsUseCase, observeLocalAiModelsUseCase: ObserveLocalAiModelsUseCase, fetchAndGetStabilityAiEnginesUseCase: FetchAndGetStabilityAiEnginesUseCase, getHuggingFaceModelsUseCase: FetchAndGetHuggingFaceModelsUseCase, @@ -43,6 +45,10 @@ class EngineSelectionViewModel( .onErrorReturn { emptyList() } .toFlowable() + val swarmModels = fetchAndGetSwarmUiModelsUseCase() + .onErrorReturn { emptyList() } + .toFlowable() + val huggingFaceModels = getHuggingFaceModelsUseCase() .onErrorReturn { emptyList() } .toFlowable() @@ -58,16 +64,17 @@ class EngineSelectionViewModel( !Flowable.combineLatest( configuration, a1111Models, + swarmModels, huggingFaceModels, stabilityAiEngines, localAiModels, - ::Quintuple, + ::Hexagonal, ) .subscribeOnMainThread(schedulersProvider) .subscribeBy( onError = ::errorLog, onComplete = EmptyLambda, - onNext = { (config, sdModels, hfModels, stEngines, localModels) -> + onNext = { (config, sdModels, swarmModels, hfModels, stEngines, localModels) -> updateState { state -> state.copy( loading = false, @@ -75,6 +82,9 @@ class EngineSelectionViewModel( sdModels = sdModels.map { it.first.title }, selectedSdModel = sdModels.firstOrNull { it.second }?.first?.title ?: state.selectedSdModel, + swarmModels = swarmModels.map { it.name }, + selectedSwarmModel = swarmModels.firstOrNull { it.name == config.swarmUiModel }?.name + ?: state.selectedSwarmModel, hfModels = hfModels.map { it.alias }, selectedHfModel = config.huggingFaceModel, stEngines = stEngines.map { it.id }, @@ -111,6 +121,8 @@ class EngineSelectionViewModel( } } + ServerSource.SWARM_UI -> preferenceManager.swarmUiModel = intent.value + ServerSource.HUGGING_FACE -> preferenceManager.huggingFaceModel = intent.value ServerSource.STABILITY_AI -> preferenceManager.stabilityAiEngineId = intent.value diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputForm.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputForm.kt index c0fc3da6..59bb7600 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputForm.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputForm.kt @@ -141,6 +141,7 @@ fun GenerationInputForm( Column(modifier = modifier) { when (state.mode) { ServerSource.AUTOMATIC1111, + ServerSource.SWARM_UI, ServerSource.STABILITY_AI, ServerSource.HUGGING_FACE, ServerSource.LOCAL -> EngineSelectionComponent( @@ -195,6 +196,7 @@ fun GenerationInputForm( // Horde does not support "negative prompt" when (state.mode) { ServerSource.AUTOMATIC1111, + ServerSource.SWARM_UI, ServerSource.HUGGING_FACE, ServerSource.STABILITY_AI, ServerSource.LOCAL -> { @@ -265,6 +267,7 @@ fun GenerationInputForm( } ServerSource.AUTOMATIC1111, + ServerSource.SWARM_UI, ServerSource.HUGGING_FACE -> { sizeTextFieldsComponent(localModifier) } @@ -285,7 +288,6 @@ fun GenerationInputForm( displayDelegate = { it.key.asUiText() }, ) } - } } @@ -432,9 +434,10 @@ fun GenerationInputForm( ) } } - // Variation seed only supported for A1111 - if (state.mode == ServerSource.AUTOMATIC1111) { - TextField( + // Variation seed supported for A1111, SwarmUI + when (state.mode) { + ServerSource.AUTOMATIC1111, + ServerSource.SWARM_UI -> TextField( modifier = Modifier .fillMaxWidth() .padding(top = 8.dp), @@ -458,10 +461,13 @@ fun GenerationInputForm( } }, ) + + else -> Unit } // Sub-seed strength is not available for Local Diffusion when (state.mode) { ServerSource.AUTOMATIC1111, + ServerSource.SWARM_UI, ServerSource.HORDE -> { Text( modifier = Modifier.padding(top = 8.dp), @@ -523,6 +529,7 @@ fun GenerationInputForm( when (state.mode) { ServerSource.AUTOMATIC1111, + ServerSource.SWARM_UI, ServerSource.STABILITY_AI, ServerSource.HORDE -> afterSlidersSection() @@ -558,7 +565,7 @@ fun GenerationInputForm( } } -fun processTaggedPrompt(keywords: List, event: ChipTextFieldEvent): String { +private fun processTaggedPrompt(keywords: List, event: ChipTextFieldEvent): String { val newKeywords = when (event) { is ChipTextFieldEvent.Add -> buildList { addAll(keywords) diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/source/ServerSourceLabel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/source/ServerSourceLabel.kt new file mode 100644 index 00000000..3638bb78 --- /dev/null +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/source/ServerSourceLabel.kt @@ -0,0 +1,19 @@ +package com.shifthackz.aisdv1.presentation.widget.source + +import androidx.compose.runtime.Composable +import androidx.compose.ui.res.stringResource +import com.shifthackz.aisdv1.domain.entity.ServerSource +import com.shifthackz.aisdv1.presentation.R + +@Composable +fun ServerSource.getName(): String { + return stringResource(id = when (this) { + ServerSource.AUTOMATIC1111 -> R.string.srv_type_own + ServerSource.HORDE -> R.string.srv_type_horde + ServerSource.LOCAL -> R.string.srv_type_local + ServerSource.HUGGING_FACE -> R.string.srv_type_hugging_face + ServerSource.OPEN_AI -> R.string.srv_type_open_ai + ServerSource.STABILITY_AI -> R.string.srv_type_stability_ai + ServerSource.SWARM_UI -> R.string.srv_type_swarm_ui + }) +} diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/toolbar/GenearionBottomToolbar.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/toolbar/GenearionBottomToolbar.kt index 83d3fbe4..d46a6ace 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/toolbar/GenearionBottomToolbar.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/toolbar/GenearionBottomToolbar.kt @@ -49,26 +49,30 @@ fun GenerationBottomToolbar( .padding(top = 8.dp), contentAlignment = Alignment.BottomCenter, ) { - if (state.mode == ServerSource.AUTOMATIC1111) { - GenerationBottomToolbarBottomLayer( - modifier = Modifier.padding(bottom = 36.dp), - strokeAccentState = strokeAccentState, - state = state, - processIntent = processIntent, - ) - Box( - modifier = Modifier - .fillMaxWidth() - .height(60.dp) - .padding(horizontal = 16.dp) - .clip( - RoundedCornerShape( - topStart = 22.dp, - topEnd = 22.dp, + when (state.mode) { + ServerSource.AUTOMATIC1111, + ServerSource.SWARM_UI -> { + GenerationBottomToolbarBottomLayer( + modifier = Modifier.padding(bottom = 36.dp), + strokeAccentState = strokeAccentState, + state = state, + processIntent = processIntent, + ) + Box( + modifier = Modifier + .fillMaxWidth() + .height(60.dp) + .padding(horizontal = 16.dp) + .clip( + RoundedCornerShape( + topStart = 22.dp, + topEnd = 22.dp, + ) ) - ) - .background(color = MaterialTheme.colorScheme.surface), - ) + .background(color = MaterialTheme.colorScheme.surface), + ) + } + else -> Unit } content() } @@ -159,28 +163,30 @@ private fun GenerationBottomToolbarBottomLayer( color = localColor, style = localStyle, ) - Spacer( - modifier = Modifier - .width(1.dp) - .height(with(LocalDensity.current) { dividerHeight.toDp() }) - .background(color = accentColor), - ) - Text( - modifier = localModifier { - processIntent( - GenerationMviIntent.SetModal( - Modal.ExtraBottomSheet( - state.prompt, - state.negativePrompt, - ExtraType.HyperNet, + if (state.mode == ServerSource.AUTOMATIC1111) { + Spacer( + modifier = Modifier + .width(1.dp) + .height(with(LocalDensity.current) { dividerHeight.toDp() }) + .background(color = accentColor), + ) + Text( + modifier = localModifier { + processIntent( + GenerationMviIntent.SetModal( + Modal.ExtraBottomSheet( + state.prompt, + state.negativePrompt, + ExtraType.HyperNet, + ), ), - ), - ) - }, - text = stringResource(id = R.string.title_hyper_net_short), - textAlign = TextAlign.Center, - color = localColor, - style = localStyle, - ) + ) + }, + text = stringResource(id = R.string.title_hyper_net_short), + textAlign = TextAlign.Center, + color = localColor, + style = localStyle, + ) + } } } diff --git a/presentation/src/main/res/values-ru/strings.xml b/presentation/src/main/res/values-ru/strings.xml index 17830fbf..8908b816 100644 --- a/presentation/src/main/res/values-ru/strings.xml +++ b/presentation/src/main/res/values-ru/strings.xml @@ -62,6 +62,7 @@ Пакетная генерация Поддержка моделей Генерация без интернета + Собственный сервер CFG Шкала: %1$s Метод выборки Стиль @@ -94,9 +95,11 @@ HTTP Базовая Укажите свой URL-адрес Stable Diffusion WebUI - Примеры URL-адресов сервера:\nhttp://192.168.0.2:7860\nhttp://yourdomain.com:7860\nhttps://yourdomain.com + Веб-интерфейс для Stable Diffusion, реализованный с использованием библиотеки Gradio. + Примеры URL-адресов сервера:\nhttp://192.168.0.2:%1$s\nhttp://yourdomain.com:%1$s\nhttps://yourdomain.com Этот режим позволяет проверить поведение программы, даже если у вас нет сервера Stable Diffusion WebUI.\n\nВ демонстрационном режиме программа игнорирует параметры генерации, не использует сервер искусственного интеллекта и возвращает фиктивные изображения. Перед подключением убедитесь, что:\n• вы используете AUTOMATIC1111 WebUI с аргументами --api --listen\n• ваш брандмауэр не блокирует порт 7860\n• телефон подключен к одной сети Wi-Fi с ПК + Перед подключением убедитесь, что:\n• ваш брандмауэр не блокирует порт 7801\n• телефон подключен к одной сети Wi-Fi с ПК Подключиться к Horde AI Horde AI – это краудсорсинговый распределенный кластер нод генерации изображений и текста. @@ -120,6 +123,9 @@ О Stability AI Stability AI движок + Укажите свйой URL-адрес Swarm UI + Модульный веб-интерфейс Stable Diffusion, в котором особое внимание уделяется обеспечению легкого доступа к инструментам, высокой производительности и расширяемости. + Эта конфигурация позволяет запускать генерации Stable Diffusion на вашем телефоне без необходимости подключаться к удаленному серверу/облаку. ВНИМАНИЕ! Функциональность Local Diffusion в бета-тестировании. Не ожидайте высококачественных изображений в локальном режиме. \n\nЭта реализация может не работать должным образом на мобильных телефонах. Производительность и скорость генерации зависят от ресурсов вашего телефона (ЦП, ОЗУ) и размера сгенерированного изображения (чем меньше размер изображения, тем быстрее генерируется). @@ -141,7 +147,7 @@ Выберите SD ML модель Очистить кэш приложения Дебаг Меню - Лора + ЛоРА Инверсия текста Инверсия Редактор тега @@ -259,9 +265,9 @@ Внести битый Base64 в БД Здесь пусто… - Добавьте что-нибуть на AUTOMATIC1111 сервере в: \n\n../models/Lora - Добавьте что-нибуть на AUTOMATIC1111 сервере в: \n\n../models/hypernetworks - Добавьте что-нибуть на AUTOMATIC1111 сервере в: \n\n../embeddings + Добавьте что-нибуть на %1$s сервере в: \n\n%2$s + Добавьте что-нибуть на %1$s сервере в: \n\n%2$s + Добавьте что-нибуть на %1$s сервере в: \n\n%2$s Зарисовка Регулировка diff --git a/presentation/src/main/res/values-tr/strings.xml b/presentation/src/main/res/values-tr/strings.xml index 9ebc6821..e8e48a1b 100644 --- a/presentation/src/main/res/values-tr/strings.xml +++ b/presentation/src/main/res/values-tr/strings.xml @@ -62,6 +62,7 @@ Toplu üretim Çoklu Modeller Çevrimdışı nesil + Kendi Sunucumuz CFG Scale: %1$s Örneklerme Yöntemi Stil ön ayarı @@ -94,9 +95,11 @@ HTTP Temel Lütfen Stable Diffusion WebUI(AUTOMATIC1111) URL adresinizi yazın. - Bazı sunucu örnekleri:\n• http://192.168.0.2:7860\n• http://alanadiniz.com:7860\n• https://alanadiniz.com - This mode allows you to test the application behavior, even if you don\'t have Stable Diffusion WebUI server.\n\nIn demo mode app ignores user prompt, does not use AI server, and returns some mock images. - Before connecting ensure that:\n• you are running AUTOMATIC1111 WebUI with flags --api --listen\n• your firewall is not blocking 7860 port\n• phone is on the same WiFi with your PC + Gradio kütüphanesi kullanılarak uygulanan Stable Diffusion için bir web arayüzü. + Bazı sunucu örnekleri:\n• http://192.168.0.2:%1$s\n• http://alanadiniz.com:%1$s\n• https://alanadiniz.com + Bu mod, Stable Diffusion WebUI sunucunuz olmasa bile uygulama davranışını test etmenize olanak tanır.\n\nDemo modunda uygulama kullanıcı istemini görmezden gelir, AI sunucusunu kullanmaz ve bazı sahte görüntüler döndürür. + Bağlanmadan önce şunlardan emin olun:\n• AUTOMATIC1111 WebUI\'yi --api --listen bayraklarıyla çalıştırıyorsunuz\n• Güvenlik duvarınız 7860 portunu engellemiyor\n• Telefon, bilgisayarınızla aynı WiFi\'da. + Bağlanmadan önce şunlardan emin olun:\n• Güvenlik duvarınız 7801 portunu engellemiyor\n• Telefon, bilgisayarınızla aynı WiFi\'da. Horde AI bulutuna bağlanın Horde AI, Görüntü oluşturma çalışanları ve metin oluşturma çalışanlarından oluşan kitle kaynaklı dağıtılmış bir kümedir. @@ -120,6 +123,9 @@ Hakkında Stability AI Stability AI motoru + Swarm UI URL\'nizi sağlayın + Araçlara kolay erişim, yüksek performans ve genişletilebilirlik üzerine odaklanan Modüler, Kararlı Yaygın Web Kullanıcı Arayüzü. + Bu yapılandırma, telefonunuzda uzak sunucuya/buluta bağlanmaya gerek kalmadan Stable Diffusion AI nesillerini çalıştırmanıza izin verir. Uyarı! Yerel Yayılma işlevi beta testindedir. Yerel modu kullanarak yüksek kaliteli görüntüler beklemeyin. \n\nBu uygulama, güçlü olmayan telefonlarda iyi çalışmayabilir. Oluşturma performansı ve hızı, telefonunuzun kaynaklarına (CPU, RAM) ve oluşturulan görüntünün boyutuna bağlıdır (görüntü boyutu ne kadar küçükse, oluşturma o kadar hızlı olur). @@ -141,7 +147,7 @@ SD Modeli seçin Uygulama önbelleğini temizle Hata Ayıklama Menüsü - Lora + LoRA Metin İnversiyon İnversiyon Etiketi düzenle @@ -259,9 +265,9 @@ Kötü Base64\'ü DB\'ye yerleştirin Burada hiçbir şey… - AUTOMATIC1111 sunucusuna biraz içerik ekleyin: \n\n../models/Lora - AUTOMATIC1111 sunucusuna biraz içerik ekleyin: \n\n../models/hypernetworks - AUTOMATIC1111 sunucusuna biraz içerik ekleyin: \n\n../embeddings + %1$s sunucusuna biraz içerik ekleyin: \n\n%2$s + %1$s sunucusuna biraz içerik ekleyin: \n\n%2$s + %1$s sunucusuna biraz içerik ekleyin: \n\n%2$s Çizmek Ayarlamak diff --git a/presentation/src/main/res/values-uk/strings.xml b/presentation/src/main/res/values-uk/strings.xml index 2b458faa..50cddb64 100644 --- a/presentation/src/main/res/values-uk/strings.xml +++ b/presentation/src/main/res/values-uk/strings.xml @@ -62,6 +62,7 @@ Пакетна генерація Підтримка моделей Генерація без інтернету + Власний сервер CFG Шкала: %1$s Метод вибірки Стиль @@ -94,9 +95,11 @@ HTTP Базова Вкажіть свою URL-адресу Stable Diffusion WebUI - Ось приклади URL-адрес сервера:\nhttp://192.168.0.2:7860\nhttp://yourdomain.com:7860\nhttps://yourdomain.com + Веб-інтерфейс для Stable Diffusion, реалізований за допомогою бібліотеки Gradio. + Ось приклади URL-адрес сервера:\nhttp://192.168.0.2:%1$s\nhttp://yourdomain.com:%1$s\nhttps://yourdomain.com Цей режим дозволяє перевірити поведінку програми, навіть якщо у вас немає сервера Stable Diffusion WebUI.\n\nУ демонстраційному режимі програма ігнорує параметри генерації, не використовує сервер штучного інтелекту та повертає фіктивні зображення. Перед підключенням переконайтеся, що:\n• ви використовуєте AUTOMATIC1111 WebUI з аргументами --api --listen\n• ваш брандмауер не блокує порт 7860\n• телефон підключено до однієї мережі Wi-Fi з вашим ПК + Перед підключенням переконайтеся, що:\n• ваш брандмауер не блокує порт 7801\n• телефон підключено до однієї мережі Wi-Fi з вашим ПК Підключитися до Horde AI Horde AI — це краудсорсинговий розподілений кластер нод генерації зображень і тексту. @@ -120,6 +123,9 @@ Про Stability AI Stability AI двигун + Provide your Swarm UI URL + Модульний веб-інтерфейс Stable Diffusion з наголосом на полегшення доступу до інструментів, високу продуктивність і розширюваність. + Ця конфігурація дозволяє запускати генерації Stable Diffusion на вашому телефоні без необхідності підключатися до віддаленого сервера/хмари. УВАГА! Функціональність Local Diffusion у бета-тестуванні. Не очікуйте високоякісних зображень у локальному режимі. \n\nЦя реалізація може не працювати належним чином на телефонах із слабкою потужністю. Продуктивність і швидкість генерації залежать від ресурсів вашого телефону (ЦП, ОЗУ) і розміру згенерованого зображення (чим менший розмір зображення, тим швидше генерується). @@ -141,7 +147,7 @@ Оберіть SD ML модель Очистити кеш додатку Дебаг Меню - Лора + ЛоРА Інверсія тексту Інверсія Редактор тегу @@ -259,9 +265,9 @@ Внести битий Base64 в БД Тут порожньо… - Додайте щось на AUTOMATIC1111 сервері до: \n\n../models/Lora - Додайте щось на AUTOMATIC1111 сервері до: \n\n../models/hypernetworks - Додайте щось на AUTOMATIC1111 сервері до: \n\n../embeddings + Додайте щось на %1$s сервері до: \n\n%2$s + Додайте щось на %1$s сервері до: \n\n%2$s + Додайте щось на %1$s сервері до: \n\n%2$s Креслення Регулювання diff --git a/presentation/src/main/res/values-zh/strings.xml b/presentation/src/main/res/values-zh/strings.xml index e2d9c8bd..76e03036 100644 --- a/presentation/src/main/res/values-zh/strings.xml +++ b/presentation/src/main/res/values-zh/strings.xml @@ -86,6 +86,7 @@ 批量 多个模型 离线生成 + 拥有自己的服务器 CFG缩放: %1$s 采样方法 样式预设 @@ -121,9 +122,11 @@ 提供您的Stable Diffusion WebUI URL - 以下是服务器URL的示例:\n• http://192.168.0.2:7860\n• http://yourdomain.com:7860\n• https://yourdomain.com + 使用 Gradio 库实现的稳定扩散的 Web 界面。 + 以下是服务器URL的示例:\n• http://192.168.0.2:%1$s\n• http://yourdomain.com:%1$s\n• https://yourdomain.com 此模式允许您测试应用程序的行为,即使您没有Stable Diffusion WebUI服务器。\n\n在演示模式下,应用程序忽略用户提示,不使用AI服务器,并返回一些模拟图像。 在连接之前确保:\n• 您正在运行AUTOMATIC1111 WebUI并带有标志 --api --listen\n• 您的防火墙没有阻止7860端口\n• 手机与您的PC在同一WiFi下 + 在连接之前确保:\n• 您的防火墙没有阻止7801端口\n• 手机与您的PC在同一WiFi下 连接到Horde AI云 @@ -151,6 +154,10 @@ 关于Stability AI Stability AI引擎 + + 提供你的 Swarm UI URL + 模块化稳定扩散 Web 用户界面,重点在于使工具易于访问、高性能和可扩展性。 + 本地扩散 此配置允许在您的手机上运行Stable Diffusion AI生成,无需连接到远程服务器/云。 @@ -177,7 +184,7 @@ 选择SD ML模型 清除应用缓存 调试菜单 - Lora + LoRA 超网络 H-Net 文本反转 @@ -320,9 +327,9 @@ 这里什么都没有… - 在AUTOMATIC1111服务器上添加一些内容到:\n\n../models/Lora - 在AUTOMATIC1111服务器上添加一些内容到:\n\n../models/hypernetworks - 在AUTOMATIC1111服务器上添加一些内容到:\n\n../embeddings + 在%1$s服务器上添加一些内容到:\n\n%2$s + 在%1$s服务器上添加一些内容到:\n\n%2$s + 在%1$s服务器上添加一些内容到:\n\n%2$s 修复 diff --git a/presentation/src/main/res/values/strings.xml b/presentation/src/main/res/values/strings.xml index c3b31a7b..fb8b6821 100755 --- a/presentation/src/main/res/values/strings.xml +++ b/presentation/src/main/res/values/strings.xml @@ -69,6 +69,7 @@ HuggingFace Open AI Stability AI + Swarm UI Prompt Negative prompt @@ -77,6 +78,7 @@ Batch Multiple Models Offline generation + Own Server CFG Scale: %1$s Sampling method Style preset @@ -110,9 +112,11 @@ HTTP Basic Provide your Stable Diffusion WebUI URL - Here are the examples of server URLs:\n• http://192.168.0.2:7860\n• http://yourdomain.com:7860\n• https://yourdomain.com + A web interface for Stable Diffusion, implemented using Gradio library. + Here are the examples of server URLs:\n• http://192.168.0.2:%1$s\n• http://yourdomain.com:%1$s\n• https://yourdomain.com This mode allows you to test the application behavior, even if you don\'t have Stable Diffusion WebUI server.\n\nIn demo mode app ignores user prompt, does not use AI server, and returns some mock images. Before connecting ensure that:\n• you are running AUTOMATIC1111 WebUI with flags --api --listen\n• your firewall is not blocking 7860 port\n• phone is on the same WiFi with your PC + Before connecting ensure that:\n• your firewall is not blocking 7801 port\n• phone is on the same WiFi with your PC Connect to Horde AI cloud Horde AI is a crowdsourced distributed cluster of Image generation workers and text generation workers. @@ -136,6 +140,9 @@ About Stability AI Stability AI Engine + Provide your Swarm UI URL + A Modular Stable Diffusion Web-User-Interface, with an emphasis on making tools easily accessible, high performance, and extensibility. + Local Diffusion This configuration allows to run Stable Diffusion AI generations on your phone, with no need to connect to remote server/cloud. Warning! Local Diffusion functionality is in beta-test. Don\'t expect for high quality images using local mode. \n\nThis implementation may not work well on non-powerful phones. Generation performance and speed depends on your phone resources (CPU, RAM) and the size of generated image (the smaller the image size, the faster the generation). @@ -158,7 +165,7 @@ Select SD ML Model Clear app cache Debug Menu - Lora + LoRA Hypernetworks H-Net Textual Inversion @@ -280,9 +287,9 @@ Insert bad Base64 in DB Nothing here… - Add some content on AUTOMATIC1111 server to: \n\n../models/Lora - Add some content on AUTOMATIC1111 server to: \n\n../models/hypernetworks - Add some content on AUTOMATIC1111 server to: \n\n../embeddings + Add some content on %1$s server to: \n\n%2$s + Add some content on %1$s server to: \n\n%2$s + Add some content on %1$s server to: \n\n.%2$s Inpaint Draw diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/StableDiffusionEmbeddingMocks.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/StableDiffusionEmbeddingMocks.kt index 5c83e46e..5b8f9b45 100644 --- a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/StableDiffusionEmbeddingMocks.kt +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/StableDiffusionEmbeddingMocks.kt @@ -1,8 +1,8 @@ package com.shifthackz.aisdv1.presentation.mocks -import com.shifthackz.aisdv1.domain.entity.StableDiffusionEmbedding +import com.shifthackz.aisdv1.domain.entity.Embedding -val mockStableDiffusionEmbeddings = listOf( - StableDiffusionEmbedding("5598"), - StableDiffusionEmbedding("151297"), +val mockEmbeddings = listOf( + Embedding("5598"), + Embedding("151297"), ) diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/StableDiffusionLoraMocks.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/StableDiffusionLoraMocks.kt index dd20509a..8cadd972 100644 --- a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/StableDiffusionLoraMocks.kt +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/StableDiffusionLoraMocks.kt @@ -1,14 +1,14 @@ package com.shifthackz.aisdv1.presentation.mocks -import com.shifthackz.aisdv1.domain.entity.StableDiffusionLora +import com.shifthackz.aisdv1.domain.entity.LoRA val mockStableDiffusionLoras = listOf( - StableDiffusionLora( + LoRA( name = "name_5598", alias = "alias_5598", path = "/unknown", ), - StableDiffusionLora( + LoRA( name = "name_151297", alias = "alias_151297", path = "/unknown", diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/SwarmUiModelMocks.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/SwarmUiModelMocks.kt new file mode 100644 index 00000000..167c49a1 --- /dev/null +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/SwarmUiModelMocks.kt @@ -0,0 +1,11 @@ +package com.shifthackz.aisdv1.presentation.mocks + +import com.shifthackz.aisdv1.domain.entity.SwarmUiModel + +val mockSwarmUiModels = listOf( + SwarmUiModel( + name = "5598", + title = "5598", + author = "ShiftHackZ", + ) +) diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/modal/embedding/EmbeddingViewModelTest.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/modal/embedding/EmbeddingViewModelTest.kt index 0d5535c8..602363bd 100644 --- a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/modal/embedding/EmbeddingViewModelTest.kt +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/modal/embedding/EmbeddingViewModelTest.kt @@ -1,8 +1,10 @@ package com.shifthackz.aisdv1.presentation.modal.embedding +import com.shifthackz.aisdv1.domain.entity.ServerSource +import com.shifthackz.aisdv1.domain.preference.PreferenceManager import com.shifthackz.aisdv1.domain.usecase.sdembedding.FetchAndGetEmbeddingsUseCase import com.shifthackz.aisdv1.presentation.core.CoreViewModelTest -import com.shifthackz.aisdv1.presentation.mocks.mockStableDiffusionEmbeddings +import com.shifthackz.aisdv1.presentation.mocks.mockEmbeddings import com.shifthackz.aisdv1.presentation.modal.extras.ExtrasEffect import com.shifthackz.aisdv1.presentation.model.ErrorState import com.shifthackz.aisdv1.presentation.stub.stubSchedulersProvider @@ -12,23 +14,35 @@ import io.reactivex.rxjava3.core.Single import kotlinx.coroutines.flow.firstOrNull import kotlinx.coroutines.test.runTest import org.junit.Assert +import org.junit.Before import org.junit.Test class EmbeddingViewModelTest : CoreViewModelTest() { private val stubException = Throwable("Something went wrong.") private val stubFetchAndGetEmbeddingsUseCase = mockk() + private val stubPreferenceManager = mockk() override fun initializeViewModel() = EmbeddingViewModel( fetchAndGetEmbeddingsUseCase = stubFetchAndGetEmbeddingsUseCase, + preferenceManager = stubPreferenceManager, schedulersProvider = stubSchedulersProvider, ) + @Before + override fun initialize() { + super.initialize() + + every { + stubPreferenceManager.source + } returns ServerSource.AUTOMATIC1111 + } + @Test fun `given update data, fetch embeddings successful, expected UI state with embeddings list`() { every { stubFetchAndGetEmbeddingsUseCase() - } returns Single.just(mockStableDiffusionEmbeddings) + } returns Single.just(mockEmbeddings) viewModel.updateData("prompt", "negative") @@ -106,7 +120,7 @@ class EmbeddingViewModelTest : CoreViewModelTest() { fun `given received ToggleItem intent, expected item from intent isInNegativePrompt changed in UI state`() { every { stubFetchAndGetEmbeddingsUseCase() - } returns Single.just(mockStableDiffusionEmbeddings) + } returns Single.just(mockEmbeddings) viewModel.updateData("prompt", "negative") diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/modal/extras/ExtrasViewModelTest.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/modal/extras/ExtrasViewModelTest.kt index 93b872c0..8eed3235 100644 --- a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/modal/extras/ExtrasViewModelTest.kt +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/modal/extras/ExtrasViewModelTest.kt @@ -1,6 +1,8 @@ package com.shifthackz.aisdv1.presentation.modal.extras import com.shifthackz.aisdv1.core.common.time.TimeProvider +import com.shifthackz.aisdv1.domain.entity.ServerSource +import com.shifthackz.aisdv1.domain.preference.PreferenceManager import com.shifthackz.aisdv1.domain.usecase.sdhypernet.FetchAndGetHyperNetworksUseCase import com.shifthackz.aisdv1.domain.usecase.sdlora.FetchAndGetLorasUseCase import com.shifthackz.aisdv1.presentation.core.CoreViewModelTest @@ -22,12 +24,14 @@ class ExtrasViewModelTest : CoreViewModelTest() { private val stubException = Throwable("Something went wrong.") private val stubFetchAndGetLorasUseCase = mockk() private val stubFetchAndGetHyperNetworksUseCase = mockk() + private val stubPreferenceManager = mockk() private val stubTimeProvider = mockk() override fun initializeViewModel() = ExtrasViewModel( stubFetchAndGetLorasUseCase, stubFetchAndGetHyperNetworksUseCase, stubSchedulersProvider, + stubPreferenceManager, stubTimeProvider, ) @@ -35,6 +39,10 @@ class ExtrasViewModelTest : CoreViewModelTest() { override fun initialize() { super.initialize() + every { + stubPreferenceManager.source + } returns ServerSource.AUTOMATIC1111 + every { stubTimeProvider.nanoTime() } returns MOCK_SYS_TIME diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/navigation/router/main/MainRouterImplTest.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/navigation/router/main/MainRouterImplTest.kt index 891ff322..64527033 100644 --- a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/navigation/router/main/MainRouterImplTest.kt +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/navigation/router/main/MainRouterImplTest.kt @@ -58,7 +58,7 @@ class MainRouterImplTest { .assertNoErrors() .assertValueAt(0) { actual -> val expectedRoute = - "${Constants.ROUTE_SERVER_SETUP}/${ServerSetupLaunchSource.SPLASH.key}" + "${Constants.ROUTE_SERVER_SETUP}/${ServerSetupLaunchSource.SPLASH.ordinal}" actual is NavigationEffect.Navigate.RouteBuilder && actual.route == expectedRoute } @@ -73,7 +73,7 @@ class MainRouterImplTest { .assertNoErrors() .assertValueAt(0) { actual -> val expectedRoute = - "${Constants.ROUTE_SERVER_SETUP}/${ServerSetupLaunchSource.SETTINGS.key}" + "${Constants.ROUTE_SERVER_SETUP}/${ServerSetupLaunchSource.SETTINGS.ordinal}" actual is NavigationEffect.Navigate.RouteBuilder && actual.route == expectedRoute } diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModelTest.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModelTest.kt index e6bb063c..7be36cd4 100644 --- a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModelTest.kt +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModelTest.kt @@ -26,7 +26,6 @@ import io.mockk.verify import io.reactivex.rxjava3.core.Completable import io.reactivex.rxjava3.core.Observable import io.reactivex.rxjava3.core.Single -import kotlinx.coroutines.flow.firstOrNull import kotlinx.coroutines.test.runTest import org.junit.Assert import org.junit.Before @@ -36,8 +35,7 @@ class ServerSetupViewModelTest : CoreViewModelTest() { private val stubGetConfigurationUseCase = mockk() private val stubGetLocalAiModelsUseCase = mockk() - private val stubFetchAndGetHuggingFaceModelsUseCase = - mockk() + private val stubFetchAndGetHuggingFaceModelsUseCase = mockk() private val stubUrlValidator = mockk() private val stubCommonStringValidator = mockk() private val stubSetupConnectionInterActor = mockk() @@ -342,16 +340,6 @@ class ServerSetupViewModelTest : CoreViewModelTest() { } } - @Test - fun `given received LaunchManageStoragePermission intent, expected LaunchManageStoragePermission effect delivered to effect collector`() { - viewModel.processIntent(ServerSetupIntent.LaunchManageStoragePermission) - runTest { - val expected = ServerSetupEffect.LaunchManageStoragePermission - val actual = viewModel.effect.firstOrNull() - Assert.assertEquals(expected, actual) - } - } - @Test fun `given received NavigateBack intent, expected router navigateBack() method called`() { every { diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModelTest.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModelTest.kt index cf2d83cb..b85b43fc 100644 --- a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModelTest.kt +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModelTest.kt @@ -11,11 +11,13 @@ import com.shifthackz.aisdv1.domain.usecase.sdmodel.GetStableDiffusionModelsUseC import com.shifthackz.aisdv1.domain.usecase.sdmodel.SelectStableDiffusionModelUseCase import com.shifthackz.aisdv1.domain.usecase.settings.GetConfigurationUseCase import com.shifthackz.aisdv1.domain.usecase.stabilityai.FetchAndGetStabilityAiEnginesUseCase +import com.shifthackz.aisdv1.domain.usecase.swarmmodel.FetchAndGetSwarmUiModelsUseCase import com.shifthackz.aisdv1.presentation.core.CoreViewModelTest import com.shifthackz.aisdv1.presentation.mocks.mockHuggingFaceModels import com.shifthackz.aisdv1.presentation.mocks.mockLocalAiModels import com.shifthackz.aisdv1.presentation.mocks.mockStabilityAiEngines import com.shifthackz.aisdv1.presentation.mocks.mockStableDiffusionModels +import com.shifthackz.aisdv1.presentation.mocks.mockSwarmUiModels import com.shifthackz.aisdv1.presentation.stub.stubSchedulersProvider import io.mockk.every import io.mockk.mockk @@ -42,10 +44,9 @@ class EngineSelectionViewModelTest : CoreViewModelTest private val stubSelectStableDiffusionModelUseCase = mockk() private val stubGetStableDiffusionModelsUseCase = mockk() private val stubObserveLocalAiModelsUseCase = mockk() - private val stubFetchAndGetStabilityAiEnginesUseCase = - mockk() - private val stubFetchAndGetHuggingFaceModelsUseCase = - mockk() + private val stubFetchAndGetStabilityAiEnginesUseCase = mockk() + private val stubFetchAndGetHuggingFaceModelsUseCase = mockk() + private val stubFetchAndGetSwarmUiModelsUseCase = mockk() override fun initializeViewModel() = EngineSelectionViewModel( preferenceManager = stubPreferenceManager, @@ -56,6 +57,7 @@ class EngineSelectionViewModelTest : CoreViewModelTest observeLocalAiModelsUseCase = stubObserveLocalAiModelsUseCase, fetchAndGetStabilityAiEnginesUseCase = stubFetchAndGetStabilityAiEnginesUseCase, getHuggingFaceModelsUseCase = stubFetchAndGetHuggingFaceModelsUseCase, + fetchAndGetSwarmUiModelsUseCase = stubFetchAndGetSwarmUiModelsUseCase, ) @Before @@ -106,6 +108,8 @@ class EngineSelectionViewModelTest : CoreViewModelTest selectedStEngine = "5598", localAiModels = listOf(LocalAiModel.CUSTOM), selectedLocalAiModelId = "CUSTOM", + swarmModels = listOf("5598"), + selectedSwarmModel = "5598", ) val actual = viewModel.state.value Assert.assertEquals(expected, actual) @@ -176,6 +180,21 @@ class EngineSelectionViewModelTest : CoreViewModelTest } } + @Test + fun `given received EngineSelectionIntent, source is SWARM_UI, expected swarmModel changed in preference`() { + mockInitialData(DataTestCase.Mock, ServerSource.SWARM_UI) + + every { + stubPreferenceManager::swarmUiModel.set(any()) + } returns Unit + + viewModel.processIntent(EngineSelectionIntent("151297")) + + verify { + stubPreferenceManager::swarmUiModel.set("151297") + } + } + @Test fun `given received EngineSelectionIntent, source is HUGGING_FACE, expected huggingFaceModel changed in preference`() { mockInitialData(DataTestCase.Mock, ServerSource.HUGGING_FACE) @@ -240,6 +259,7 @@ class EngineSelectionViewModelTest : CoreViewModelTest Configuration( huggingFaceModel = "prompthero/openjourney-v4", stabilityAiEngineId = "5598", + swarmUiModel = "5598", localModelId = "CUSTOM", source = source, ), @@ -273,6 +293,14 @@ class EngineSelectionViewModelTest : CoreViewModelTest DataTestCase.Exception -> Single.error(stubException) } + every { + stubFetchAndGetSwarmUiModelsUseCase() + } returns when (testCase) { + DataTestCase.Mock -> Single.just(mockSwarmUiModels) + DataTestCase.Empty -> Single.just(emptyList()) + DataTestCase.Exception -> Single.error(stubException) + } + stubLocalAiModels.onNext( when (testCase) { DataTestCase.Mock -> Result.success(mockLocalAiModels) diff --git a/storage/schemas/com.shifthackz.aisdv1.storage.db.cache.CacheDatabase/1.json b/storage/schemas/com.shifthackz.aisdv1.storage.db.cache.CacheDatabase/1.json index 59134bbd..2f273525 100644 --- a/storage/schemas/com.shifthackz.aisdv1.storage.db.cache.CacheDatabase/1.json +++ b/storage/schemas/com.shifthackz.aisdv1.storage.db.cache.CacheDatabase/1.json @@ -2,7 +2,7 @@ "formatVersion": 1, "database": { "version": 1, - "identityHash": "49878af78a080f644679bc65b796475a", + "identityHash": "5dc36f6910330fd3f65e65c4e15e56c2", "entities": [ { "tableName": "server_config", @@ -219,12 +219,50 @@ }, "indices": [], "foreignKeys": [] + }, + { + "tableName": "swarm_models", + "createSql": "CREATE TABLE IF NOT EXISTS `${TABLE_NAME}` (`id` TEXT NOT NULL, `name` TEXT NOT NULL, `title` TEXT NOT NULL, `author` TEXT NOT NULL, PRIMARY KEY(`id`))", + "fields": [ + { + "fieldPath": "id", + "columnName": "id", + "affinity": "TEXT", + "notNull": true + }, + { + "fieldPath": "name", + "columnName": "name", + "affinity": "TEXT", + "notNull": true + }, + { + "fieldPath": "title", + "columnName": "title", + "affinity": "TEXT", + "notNull": true + }, + { + "fieldPath": "author", + "columnName": "author", + "affinity": "TEXT", + "notNull": true + } + ], + "primaryKey": { + "autoGenerate": false, + "columnNames": [ + "id" + ] + }, + "indices": [], + "foreignKeys": [] } ], "views": [], "setupQueries": [ "CREATE TABLE IF NOT EXISTS room_master_table (id INTEGER PRIMARY KEY,identity_hash TEXT)", - "INSERT OR REPLACE INTO room_master_table (id,identity_hash) VALUES(42, '49878af78a080f644679bc65b796475a')" + "INSERT OR REPLACE INTO room_master_table (id,identity_hash) VALUES(42, '5dc36f6910330fd3f65e65c4e15e56c2')" ] } } \ No newline at end of file diff --git a/storage/src/main/java/com/shifthackz/aisdv1/storage/db/cache/CacheDatabase.kt b/storage/src/main/java/com/shifthackz/aisdv1/storage/db/cache/CacheDatabase.kt index 7db368fa..3dc34f74 100755 --- a/storage/src/main/java/com/shifthackz/aisdv1/storage/db/cache/CacheDatabase.kt +++ b/storage/src/main/java/com/shifthackz/aisdv1/storage/db/cache/CacheDatabase.kt @@ -12,12 +12,14 @@ import com.shifthackz.aisdv1.storage.db.cache.dao.StableDiffusionHyperNetworkDao import com.shifthackz.aisdv1.storage.db.cache.dao.StableDiffusionLoraDao import com.shifthackz.aisdv1.storage.db.cache.dao.StableDiffusionModelDao import com.shifthackz.aisdv1.storage.db.cache.dao.StableDiffusionSamplerDao +import com.shifthackz.aisdv1.storage.db.cache.dao.SwarmUiModelDao import com.shifthackz.aisdv1.storage.db.cache.entity.ServerConfigurationEntity import com.shifthackz.aisdv1.storage.db.cache.entity.StableDiffusionEmbeddingEntity import com.shifthackz.aisdv1.storage.db.cache.entity.StableDiffusionHyperNetworkEntity import com.shifthackz.aisdv1.storage.db.cache.entity.StableDiffusionLoraEntity import com.shifthackz.aisdv1.storage.db.cache.entity.StableDiffusionModelEntity import com.shifthackz.aisdv1.storage.db.cache.entity.StableDiffusionSamplerEntity +import com.shifthackz.aisdv1.storage.db.cache.entity.SwarmUiModelEntity @Database( version = DB_VERSION, @@ -29,6 +31,7 @@ import com.shifthackz.aisdv1.storage.db.cache.entity.StableDiffusionSamplerEntit StableDiffusionLoraEntity::class, StableDiffusionHyperNetworkEntity::class, StableDiffusionEmbeddingEntity::class, + SwarmUiModelEntity::class, ], ) @TypeConverters( @@ -42,6 +45,7 @@ internal abstract class CacheDatabase : RoomDatabase() { abstract fun sdLoraDao(): StableDiffusionLoraDao abstract fun sdHyperNetworkDao(): StableDiffusionHyperNetworkDao abstract fun sdEmbeddingDao(): StableDiffusionEmbeddingDao + abstract fun swarmUiModelDao(): SwarmUiModelDao companion object { const val DB_VERSION = 1 diff --git a/storage/src/main/java/com/shifthackz/aisdv1/storage/db/cache/contract/SwarmUiModelContract.kt b/storage/src/main/java/com/shifthackz/aisdv1/storage/db/cache/contract/SwarmUiModelContract.kt new file mode 100644 index 00000000..2a900c59 --- /dev/null +++ b/storage/src/main/java/com/shifthackz/aisdv1/storage/db/cache/contract/SwarmUiModelContract.kt @@ -0,0 +1,10 @@ +package com.shifthackz.aisdv1.storage.db.cache.contract + +internal object SwarmUiModelContract { + const val TABLE = "swarm_models" + + const val ID = "id" + const val NAME = "name" + const val TITLE = "title" + const val AUTHOR = "author" +} diff --git a/storage/src/main/java/com/shifthackz/aisdv1/storage/db/cache/dao/SwarmUiModelDao.kt b/storage/src/main/java/com/shifthackz/aisdv1/storage/db/cache/dao/SwarmUiModelDao.kt new file mode 100644 index 00000000..bda49f52 --- /dev/null +++ b/storage/src/main/java/com/shifthackz/aisdv1/storage/db/cache/dao/SwarmUiModelDao.kt @@ -0,0 +1,23 @@ +package com.shifthackz.aisdv1.storage.db.cache.dao + +import androidx.room.Dao +import androidx.room.Insert +import androidx.room.OnConflictStrategy +import androidx.room.Query +import com.shifthackz.aisdv1.storage.db.cache.contract.SwarmUiModelContract +import com.shifthackz.aisdv1.storage.db.cache.entity.SwarmUiModelEntity +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single + +@Dao +interface SwarmUiModelDao { + + @Query("SELECT * FROM ${SwarmUiModelContract.TABLE}") + fun queryAll(): Single> + + @Insert(onConflict = OnConflictStrategy.REPLACE) + fun insertList(items: List): Completable + + @Query("DELETE FROM ${SwarmUiModelContract.TABLE}") + fun deleteAll(): Completable +} diff --git a/storage/src/main/java/com/shifthackz/aisdv1/storage/db/cache/entity/SwarmUiModelEntity.kt b/storage/src/main/java/com/shifthackz/aisdv1/storage/db/cache/entity/SwarmUiModelEntity.kt new file mode 100644 index 00000000..f79f12b0 --- /dev/null +++ b/storage/src/main/java/com/shifthackz/aisdv1/storage/db/cache/entity/SwarmUiModelEntity.kt @@ -0,0 +1,19 @@ +package com.shifthackz.aisdv1.storage.db.cache.entity + +import androidx.room.ColumnInfo +import androidx.room.Entity +import androidx.room.PrimaryKey +import com.shifthackz.aisdv1.storage.db.cache.contract.SwarmUiModelContract + +@Entity(tableName = SwarmUiModelContract.TABLE) +data class SwarmUiModelEntity( + @PrimaryKey(autoGenerate = false) + @ColumnInfo(name = SwarmUiModelContract.ID) + val id: String, + @ColumnInfo(name = SwarmUiModelContract.NAME) + val name: String, + @ColumnInfo(name = SwarmUiModelContract.TITLE) + val title: String, + @ColumnInfo(name = SwarmUiModelContract.AUTHOR) + val author: String, +) diff --git a/storage/src/main/java/com/shifthackz/aisdv1/storage/di/DatabaseModule.kt b/storage/src/main/java/com/shifthackz/aisdv1/storage/di/DatabaseModule.kt index f84172db..68255373 100755 --- a/storage/src/main/java/com/shifthackz/aisdv1/storage/di/DatabaseModule.kt +++ b/storage/src/main/java/com/shifthackz/aisdv1/storage/di/DatabaseModule.kt @@ -43,6 +43,7 @@ val databaseModule = module { single { get().sdHyperNetworkDao() } single { get().sdEmbeddingDao() } single { get().serverConfigurationDao() } + single { get().swarmUiModelDao() } //endregion //region PERSISTENT DB DAOs