diff --git a/core/common/src/main/java/com/shifthackz/aisdv1/core/common/model/Quintuple.kt b/core/common/src/main/java/com/shifthackz/aisdv1/core/common/model/Quintuple.kt new file mode 100644 index 00000000..e0987f1c --- /dev/null +++ b/core/common/src/main/java/com/shifthackz/aisdv1/core/common/model/Quintuple.kt @@ -0,0 +1,14 @@ +package com.shifthackz.aisdv1.core.common.model + +import java.io.Serializable + +data class Quintuple( + val first: A, + val second: B, + val third: C, + val fourth: D, + val fifth: E, +) : Serializable { + + override fun toString(): String = "($first, $second, $third, $fourth, $fifth)" +} diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/local/DownloadableModelLocalDataSource.kt b/data/src/main/java/com/shifthackz/aisdv1/data/local/DownloadableModelLocalDataSource.kt index 9fdedf85..5284b92d 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/local/DownloadableModelLocalDataSource.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/local/DownloadableModelLocalDataSource.kt @@ -11,6 +11,7 @@ import com.shifthackz.aisdv1.domain.preference.PreferenceManager import com.shifthackz.aisdv1.storage.db.persistent.dao.LocalModelDao import com.shifthackz.aisdv1.storage.db.persistent.entity.LocalModelEntity import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Flowable import io.reactivex.rxjava3.core.Observable import io.reactivex.rxjava3.core.Single import java.io.File @@ -21,6 +22,7 @@ internal class DownloadableModelLocalDataSource( private val preferenceManager: PreferenceManager, private val buildInfoProvider: BuildInfoProvider, ) : DownloadableModelDataSource.Local { + override fun getAll(): Single> = dao.query() .map(List::mapEntityToDomain) .map { models -> @@ -45,6 +47,17 @@ internal class DownloadableModelLocalDataSource( .flatMap(::getById) .onErrorResumeNext { Single.error(Throwable("No selected model")) } + override fun observeAll(): Flowable> = dao + .observe() + .map(List::mapEntityToDomain) + .map { models -> + buildList { + addAll(models) + if (buildInfoProvider.type == BuildType.FOSS) add(LocalAiModel.CUSTOM) + } + } + .flatMap { models -> models.withLocalData().toFlowable() } + override fun select(id: String): Completable = Completable.fromAction { preferenceManager.localModelId = id } diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/repository/DownloadableModelRepositoryImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/repository/DownloadableModelRepositoryImpl.kt index 5ee0d356..5bde3662 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/repository/DownloadableModelRepositoryImpl.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/repository/DownloadableModelRepositoryImpl.kt @@ -25,5 +25,8 @@ internal class DownloadableModelRepositoryImpl( .onErrorResumeNext { localDataSource.getAll() } override fun getById(id: String) = localDataSource.getById(id) + + override fun observeAll() = localDataSource.observeAll() + override fun select(id: String) = localDataSource.select(id) } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/DownloadableModelDataSource.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/DownloadableModelDataSource.kt index 801d14b9..b2320323 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/DownloadableModelDataSource.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/DownloadableModelDataSource.kt @@ -3,6 +3,7 @@ package com.shifthackz.aisdv1.domain.datasource import com.shifthackz.aisdv1.domain.entity.DownloadState import com.shifthackz.aisdv1.domain.entity.LocalAiModel import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Flowable import io.reactivex.rxjava3.core.Observable import io.reactivex.rxjava3.core.Single @@ -17,6 +18,7 @@ sealed interface DownloadableModelDataSource { fun getAll(): Single> fun getById(id: String): Single fun getSelected(): Single + fun observeAll(): Flowable> fun select(id: String): Completable fun save(list: List): Completable fun isDownloaded(id: String): Single diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/di/DomainModule.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/di/DomainModule.kt index a8fa3515..23338d53 100755 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/di/DomainModule.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/di/DomainModule.kt @@ -34,6 +34,8 @@ import com.shifthackz.aisdv1.domain.usecase.downloadable.DownloadModelUseCase import com.shifthackz.aisdv1.domain.usecase.downloadable.DownloadModelUseCaseImpl import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalAiModelsUseCase import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalAiModelsUseCaseImpl +import com.shifthackz.aisdv1.domain.usecase.downloadable.ObserveLocalAiModelsUseCase +import com.shifthackz.aisdv1.domain.usecase.downloadable.ObserveLocalAiModelsUseCaseImpl import com.shifthackz.aisdv1.domain.usecase.gallery.DeleteGalleryItemUseCase import com.shifthackz.aisdv1.domain.usecase.gallery.DeleteGalleryItemUseCaseImpl import com.shifthackz.aisdv1.domain.usecase.gallery.GetAllGalleryUseCase @@ -136,6 +138,7 @@ internal val useCasesModule = module { factoryOf(::ObserveLocalDiffusionProcessStatusUseCaseImpl) bind ObserveLocalDiffusionProcessStatusUseCase::class factoryOf(::GetLocalAiModelsUseCaseImpl) bind GetLocalAiModelsUseCase::class factoryOf(::DownloadModelUseCaseImpl) bind DownloadModelUseCase::class + factoryOf(::ObserveLocalAiModelsUseCaseImpl) bind ObserveLocalAiModelsUseCase::class factoryOf(::DeleteModelUseCaseImpl) bind DeleteModelUseCase::class factoryOf(::AcquireWakelockUseCaseImpl) bind AcquireWakelockUseCase::class factoryOf(::ReleaseWakeLockUseCaseImpl) bind ReleaseWakeLockUseCase::class diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/DownloadableModelRepository.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/DownloadableModelRepository.kt index d5a50419..dcab5955 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/DownloadableModelRepository.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/DownloadableModelRepository.kt @@ -3,6 +3,7 @@ package com.shifthackz.aisdv1.domain.repository import com.shifthackz.aisdv1.domain.entity.DownloadState import com.shifthackz.aisdv1.domain.entity.LocalAiModel import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Flowable import io.reactivex.rxjava3.core.Observable import io.reactivex.rxjava3.core.Single @@ -12,5 +13,6 @@ interface DownloadableModelRepository { fun delete(id: String): Completable fun getAll(): Single> fun getById(id: String): Single + fun observeAll(): Flowable> fun select(id: String): Completable } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalAiModelsUseCase.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalAiModelsUseCase.kt new file mode 100644 index 00000000..f79d31b1 --- /dev/null +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalAiModelsUseCase.kt @@ -0,0 +1,8 @@ +package com.shifthackz.aisdv1.domain.usecase.downloadable + +import com.shifthackz.aisdv1.domain.entity.LocalAiModel +import io.reactivex.rxjava3.core.Flowable + +interface ObserveLocalAiModelsUseCase { + operator fun invoke(): Flowable> +} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalAiModelsUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalAiModelsUseCaseImpl.kt new file mode 100644 index 00000000..21f14410 --- /dev/null +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalAiModelsUseCaseImpl.kt @@ -0,0 +1,10 @@ +package com.shifthackz.aisdv1.domain.usecase.downloadable + +import com.shifthackz.aisdv1.domain.repository.DownloadableModelRepository + +internal class ObserveLocalAiModelsUseCaseImpl( + private val repository: DownloadableModelRepository, +) : ObserveLocalAiModelsUseCase { + + override fun invoke() = repository.observeAll() +} diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/steps/SourceSelectionStep.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/steps/SourceSelectionStep.kt index 8756c7d5..54c9f1b0 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/steps/SourceSelectionStep.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/steps/SourceSelectionStep.kt @@ -1,13 +1,17 @@ package com.shifthackz.aisdv1.presentation.screen.setup.steps -import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.Spacer import androidx.compose.foundation.layout.fillMaxWidth import androidx.compose.foundation.layout.height import androidx.compose.foundation.layout.padding +import androidx.compose.foundation.lazy.LazyColumn +import androidx.compose.foundation.lazy.items +import androidx.compose.foundation.lazy.rememberLazyListState import androidx.compose.runtime.Composable +import androidx.compose.runtime.LaunchedEffect import androidx.compose.ui.Modifier import androidx.compose.ui.unit.dp +import com.shifthackz.aisdv1.domain.entity.ServerSource import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupIntent import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupState import com.shifthackz.aisdv1.presentation.screen.setup.components.ConfigurationModeButton @@ -18,21 +22,31 @@ fun SourceSelectionStep( state: ServerSetupState, processIntent: (ServerSetupIntent) -> Unit = {}, ) { - BaseServerSetupStateWrapper(modifier) { - Column { - Spacer(modifier = Modifier.height(12.dp)) - state.allowedModes.forEach { mode -> - ConfigurationModeButton( - modifier = Modifier - .fillMaxWidth() - .padding(horizontal = 16.dp, vertical = 4.dp), - state = state, - mode = mode, - onClick = { - processIntent(ServerSetupIntent.UpdateServerMode(it)) - }, - ) - } + val lazyListState = rememberLazyListState() + LaunchedEffect(state.mode) { + // Adding 1 here, because item with index == 0 is top spacer + lazyListState.animateScrollToItem(state.mode.ordinal + 1) + } + LazyColumn( + modifier = modifier, + state = lazyListState, + ) { + item(key = "SPACER_TOP") { Spacer(modifier = Modifier.height(12.dp)) } + items( + items = state.allowedModes, + key = ServerSource::key, + ) { mode -> + ConfigurationModeButton( + modifier = Modifier + .fillMaxWidth() + .padding(horizontal = 16.dp, vertical = 4.dp), + state = state, + mode = mode, + onClick = { + processIntent(ServerSetupIntent.UpdateServerMode(it)) + }, + ) } + item(key = "SPACER_BOTTOM") { Spacer(modifier = Modifier.height(32.dp)) } } } diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionComponent.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionComponent.kt index 0b16ed2c..cb007e51 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionComponent.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionComponent.kt @@ -45,7 +45,18 @@ fun EngineSelectionComponent( onItemSelected = { intentHandler(EngineSelectionIntent(it)) }, ) - else -> Unit + ServerSource.LOCAL -> DropdownTextField( + label = R.string.hint_sd_model.asUiText(), + loading = state.loading, + modifier = modifier, + value = state.localAiModels.firstOrNull { it.id == state.selectedLocalAiModelId }, + items = state.localAiModels, + onItemSelected = { intentHandler(EngineSelectionIntent(it.id)) }, + displayDelegate = { it.name.asUiText() }, + ) + + ServerSource.HORDE -> Unit + ServerSource.OPEN_AI -> Unit } } } diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionState.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionState.kt index 414ce03e..7b510934 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionState.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionState.kt @@ -1,6 +1,7 @@ package com.shifthackz.aisdv1.presentation.widget.engine import androidx.compose.runtime.Immutable +import com.shifthackz.aisdv1.domain.entity.LocalAiModel import com.shifthackz.aisdv1.domain.entity.ServerSource import com.shifthackz.android.core.mvi.MviState @@ -14,4 +15,6 @@ data class EngineSelectionState( val selectedHfModel: String = "", val stEngines: List = emptyList(), val selectedStEngine: String = "", + val localAiModels: List = emptyList(), + val selectedLocalAiModelId: String = "", ): MviState diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModel.kt index ba18e7fe..af32aeee 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModel.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModel.kt @@ -2,12 +2,14 @@ package com.shifthackz.aisdv1.presentation.widget.engine import com.shifthackz.aisdv1.core.common.extensions.EmptyLambda import com.shifthackz.aisdv1.core.common.log.errorLog -import com.shifthackz.aisdv1.core.common.model.Quadruple +import com.shifthackz.aisdv1.core.common.model.Quintuple import com.shifthackz.aisdv1.core.common.schedulers.SchedulersProvider import com.shifthackz.aisdv1.core.common.schedulers.subscribeOnMainThread import com.shifthackz.aisdv1.core.viewmodel.MviRxViewModel +import com.shifthackz.aisdv1.domain.entity.LocalAiModel import com.shifthackz.aisdv1.domain.entity.ServerSource import com.shifthackz.aisdv1.domain.preference.PreferenceManager +import com.shifthackz.aisdv1.domain.usecase.downloadable.ObserveLocalAiModelsUseCase import com.shifthackz.aisdv1.domain.usecase.huggingface.FetchAndGetHuggingFaceModelsUseCase import com.shifthackz.aisdv1.domain.usecase.sdmodel.GetStableDiffusionModelsUseCase import com.shifthackz.aisdv1.domain.usecase.sdmodel.SelectStableDiffusionModelUseCase @@ -23,6 +25,7 @@ class EngineSelectionViewModel( private val getConfigurationUseCase: GetConfigurationUseCase, private val selectStableDiffusionModelUseCase: SelectStableDiffusionModelUseCase, private val getStableDiffusionModelsUseCase: GetStableDiffusionModelsUseCase, + observeLocalAiModelsUseCase: ObserveLocalAiModelsUseCase, fetchAndGetStabilityAiEnginesUseCase: FetchAndGetStabilityAiEnginesUseCase, getHuggingFaceModelsUseCase: FetchAndGetHuggingFaceModelsUseCase, ) : MviRxViewModel() { @@ -46,28 +49,37 @@ class EngineSelectionViewModel( .onErrorReturn { emptyList() } .toFlowable() + val localAiModels = observeLocalAiModelsUseCase() + .map { models -> models.filter { it.downloaded || it.id == LocalAiModel.CUSTOM.id } } + .onErrorReturn { emptyList() } + !Flowable.combineLatest( configuration, a1111Models, huggingFaceModels, stabilityAiEngines, - ::Quadruple, + localAiModels, + ::Quintuple, ) .subscribeOnMainThread(schedulersProvider) .subscribeBy( onError = ::errorLog, onComplete = EmptyLambda, - onNext = { (config, sdModels, hfModels, stEngines) -> + onNext = { (config, sdModels, hfModels, stEngines, localModels) -> updateState { state -> state.copy( loading = false, mode = config.source, sdModels = sdModels.map { it.first.title }, - selectedSdModel = sdModels.first { it.second }.first.title, + selectedSdModel = sdModels.firstOrNull { it.second }?.first?.title + ?: state.selectedSdModel, hfModels = hfModels.map { it.alias }, selectedHfModel = config.huggingFaceModel, stEngines = stEngines.map { it.id }, selectedStEngine = config.stabilityAiEngineId, + localAiModels = localModels, + selectedLocalAiModelId = localModels.firstOrNull { it.id == config.localModelId }?.id + ?: state.selectedLocalAiModelId ) } }, @@ -82,7 +94,7 @@ class EngineSelectionViewModel( it.copy( loading = true, selectedSdModel = intent.value, - ) + ) } } .andThen(getStableDiffusionModelsUseCase()) @@ -101,6 +113,8 @@ class EngineSelectionViewModel( ServerSource.STABILITY_AI -> preferenceManager.stabilityAiEngineId = intent.value + ServerSource.LOCAL -> preferenceManager.localModelId = intent.value + else -> Unit } } diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/input/DropdownTextField.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/input/DropdownTextField.kt index 7f7b7b3f..ad2ae7d0 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/input/DropdownTextField.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/input/DropdownTextField.kt @@ -30,7 +30,7 @@ fun DropdownTextField( modifier: Modifier = Modifier, loading: Boolean = false, label: UiText = UiText.empty, - value: T, + value: T?, items: List = emptyList(), onItemSelected: (T) -> Unit = {}, displayDelegate: (T) -> UiText = { t -> t.toString().asUiText() }, @@ -46,7 +46,7 @@ fun DropdownTextField( modifier = Modifier .fillMaxWidth() .menuAnchor(), - value = displayDelegate(value).asString(), + value = value?.let { displayDelegate(it).asString() } ?: "", onValueChange = {}, readOnly = true, label = { Text(label.asString()) }, diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputForm.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputForm.kt index 8f2ea769..c0fc3da6 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputForm.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputForm.kt @@ -142,7 +142,8 @@ fun GenerationInputForm( when (state.mode) { ServerSource.AUTOMATIC1111, ServerSource.STABILITY_AI, - ServerSource.HUGGING_FACE -> EngineSelectionComponent( + ServerSource.HUGGING_FACE, + ServerSource.LOCAL -> EngineSelectionComponent( modifier = Modifier .fillMaxWidth() .padding(top = 8.dp), diff --git a/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/dao/LocalModelDao.kt b/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/dao/LocalModelDao.kt index 9336a544..435d4c69 100644 --- a/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/dao/LocalModelDao.kt +++ b/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/dao/LocalModelDao.kt @@ -7,6 +7,7 @@ import androidx.room.Query import com.shifthackz.aisdv1.storage.db.persistent.contract.LocalModelContract import com.shifthackz.aisdv1.storage.db.persistent.entity.LocalModelEntity import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Flowable import io.reactivex.rxjava3.core.Single @Dao @@ -15,6 +16,9 @@ interface LocalModelDao { @Query("SELECT * FROM ${LocalModelContract.TABLE}") fun query(): Single> + @Query("SELECT * FROM ${LocalModelContract.TABLE}") + fun observe(): Flowable> + @Query("SELECT * FROM ${LocalModelContract.TABLE} WHERE ${LocalModelContract.ID} = :id LIMIT 1") fun queryById(id: String): Single