Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package com.shifthackz.aisdv1.core.common.model

import java.io.Serializable

data class Quintuple<out A, out B, out C, out D, out E>(
val first: A,
val second: B,
val third: C,
val fourth: D,
val fifth: E,
) : Serializable {

override fun toString(): String = "($first, $second, $third, $fourth, $fifth)"
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import com.shifthackz.aisdv1.domain.preference.PreferenceManager
import com.shifthackz.aisdv1.storage.db.persistent.dao.LocalModelDao
import com.shifthackz.aisdv1.storage.db.persistent.entity.LocalModelEntity
import io.reactivex.rxjava3.core.Completable
import io.reactivex.rxjava3.core.Flowable
import io.reactivex.rxjava3.core.Observable
import io.reactivex.rxjava3.core.Single
import java.io.File
Expand All @@ -21,6 +22,7 @@ internal class DownloadableModelLocalDataSource(
private val preferenceManager: PreferenceManager,
private val buildInfoProvider: BuildInfoProvider,
) : DownloadableModelDataSource.Local {

override fun getAll(): Single<List<LocalAiModel>> = dao.query()
.map(List<LocalModelEntity>::mapEntityToDomain)
.map { models ->
Expand All @@ -45,6 +47,17 @@ internal class DownloadableModelLocalDataSource(
.flatMap(::getById)
.onErrorResumeNext { Single.error(Throwable("No selected model")) }

override fun observeAll(): Flowable<List<LocalAiModel>> = dao
.observe()
.map(List<LocalModelEntity>::mapEntityToDomain)
.map { models ->
buildList {
addAll(models)
if (buildInfoProvider.type == BuildType.FOSS) add(LocalAiModel.CUSTOM)
}
}
.flatMap { models -> models.withLocalData().toFlowable() }

override fun select(id: String): Completable = Completable.fromAction {
preferenceManager.localModelId = id
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,8 @@ internal class DownloadableModelRepositoryImpl(
.onErrorResumeNext { localDataSource.getAll() }

override fun getById(id: String) = localDataSource.getById(id)

override fun observeAll() = localDataSource.observeAll()

override fun select(id: String) = localDataSource.select(id)
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package com.shifthackz.aisdv1.domain.datasource
import com.shifthackz.aisdv1.domain.entity.DownloadState
import com.shifthackz.aisdv1.domain.entity.LocalAiModel
import io.reactivex.rxjava3.core.Completable
import io.reactivex.rxjava3.core.Flowable
import io.reactivex.rxjava3.core.Observable
import io.reactivex.rxjava3.core.Single

Expand All @@ -17,6 +18,7 @@ sealed interface DownloadableModelDataSource {
fun getAll(): Single<List<LocalAiModel>>
fun getById(id: String): Single<LocalAiModel>
fun getSelected(): Single<LocalAiModel>
fun observeAll(): Flowable<List<LocalAiModel>>
fun select(id: String): Completable
fun save(list: List<LocalAiModel>): Completable
fun isDownloaded(id: String): Single<Boolean>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ import com.shifthackz.aisdv1.domain.usecase.downloadable.DownloadModelUseCase
import com.shifthackz.aisdv1.domain.usecase.downloadable.DownloadModelUseCaseImpl
import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalAiModelsUseCase
import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalAiModelsUseCaseImpl
import com.shifthackz.aisdv1.domain.usecase.downloadable.ObserveLocalAiModelsUseCase
import com.shifthackz.aisdv1.domain.usecase.downloadable.ObserveLocalAiModelsUseCaseImpl
import com.shifthackz.aisdv1.domain.usecase.gallery.DeleteGalleryItemUseCase
import com.shifthackz.aisdv1.domain.usecase.gallery.DeleteGalleryItemUseCaseImpl
import com.shifthackz.aisdv1.domain.usecase.gallery.GetAllGalleryUseCase
Expand Down Expand Up @@ -136,6 +138,7 @@ internal val useCasesModule = module {
factoryOf(::ObserveLocalDiffusionProcessStatusUseCaseImpl) bind ObserveLocalDiffusionProcessStatusUseCase::class
factoryOf(::GetLocalAiModelsUseCaseImpl) bind GetLocalAiModelsUseCase::class
factoryOf(::DownloadModelUseCaseImpl) bind DownloadModelUseCase::class
factoryOf(::ObserveLocalAiModelsUseCaseImpl) bind ObserveLocalAiModelsUseCase::class
factoryOf(::DeleteModelUseCaseImpl) bind DeleteModelUseCase::class
factoryOf(::AcquireWakelockUseCaseImpl) bind AcquireWakelockUseCase::class
factoryOf(::ReleaseWakeLockUseCaseImpl) bind ReleaseWakeLockUseCase::class
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package com.shifthackz.aisdv1.domain.repository
import com.shifthackz.aisdv1.domain.entity.DownloadState
import com.shifthackz.aisdv1.domain.entity.LocalAiModel
import io.reactivex.rxjava3.core.Completable
import io.reactivex.rxjava3.core.Flowable
import io.reactivex.rxjava3.core.Observable
import io.reactivex.rxjava3.core.Single

Expand All @@ -12,5 +13,6 @@ interface DownloadableModelRepository {
fun delete(id: String): Completable
fun getAll(): Single<List<LocalAiModel>>
fun getById(id: String): Single<LocalAiModel>
fun observeAll(): Flowable<List<LocalAiModel>>
fun select(id: String): Completable
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package com.shifthackz.aisdv1.domain.usecase.downloadable

import com.shifthackz.aisdv1.domain.entity.LocalAiModel
import io.reactivex.rxjava3.core.Flowable

interface ObserveLocalAiModelsUseCase {
operator fun invoke(): Flowable<List<LocalAiModel>>
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package com.shifthackz.aisdv1.domain.usecase.downloadable

import com.shifthackz.aisdv1.domain.repository.DownloadableModelRepository

internal class ObserveLocalAiModelsUseCaseImpl(
private val repository: DownloadableModelRepository,
) : ObserveLocalAiModelsUseCase {

override fun invoke() = repository.observeAll()
}
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
package com.shifthackz.aisdv1.presentation.screen.setup.steps

import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Spacer
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.height
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.lazy.LazyColumn
import androidx.compose.foundation.lazy.items
import androidx.compose.foundation.lazy.rememberLazyListState
import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.ui.Modifier
import androidx.compose.ui.unit.dp
import com.shifthackz.aisdv1.domain.entity.ServerSource
import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupIntent
import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupState
import com.shifthackz.aisdv1.presentation.screen.setup.components.ConfigurationModeButton
Expand All @@ -18,21 +22,31 @@ fun SourceSelectionStep(
state: ServerSetupState,
processIntent: (ServerSetupIntent) -> Unit = {},
) {
BaseServerSetupStateWrapper(modifier) {
Column {
Spacer(modifier = Modifier.height(12.dp))
state.allowedModes.forEach { mode ->
ConfigurationModeButton(
modifier = Modifier
.fillMaxWidth()
.padding(horizontal = 16.dp, vertical = 4.dp),
state = state,
mode = mode,
onClick = {
processIntent(ServerSetupIntent.UpdateServerMode(it))
},
)
}
val lazyListState = rememberLazyListState()
LaunchedEffect(state.mode) {
// Adding 1 here, because item with index == 0 is top spacer
lazyListState.animateScrollToItem(state.mode.ordinal + 1)
}
LazyColumn(
modifier = modifier,
state = lazyListState,
) {
item(key = "SPACER_TOP") { Spacer(modifier = Modifier.height(12.dp)) }
items(
items = state.allowedModes,
key = ServerSource::key,
) { mode ->
ConfigurationModeButton(
modifier = Modifier
.fillMaxWidth()
.padding(horizontal = 16.dp, vertical = 4.dp),
state = state,
mode = mode,
onClick = {
processIntent(ServerSetupIntent.UpdateServerMode(it))
},
)
}
item(key = "SPACER_BOTTOM") { Spacer(modifier = Modifier.height(32.dp)) }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,18 @@ fun EngineSelectionComponent(
onItemSelected = { intentHandler(EngineSelectionIntent(it)) },
)

else -> Unit
ServerSource.LOCAL -> DropdownTextField(
label = R.string.hint_sd_model.asUiText(),
loading = state.loading,
modifier = modifier,
value = state.localAiModels.firstOrNull { it.id == state.selectedLocalAiModelId },
items = state.localAiModels,
onItemSelected = { intentHandler(EngineSelectionIntent(it.id)) },
displayDelegate = { it.name.asUiText() },
)

ServerSource.HORDE -> Unit
ServerSource.OPEN_AI -> Unit
}
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.shifthackz.aisdv1.presentation.widget.engine

import androidx.compose.runtime.Immutable
import com.shifthackz.aisdv1.domain.entity.LocalAiModel
import com.shifthackz.aisdv1.domain.entity.ServerSource
import com.shifthackz.android.core.mvi.MviState

Expand All @@ -14,4 +15,6 @@ data class EngineSelectionState(
val selectedHfModel: String = "",
val stEngines: List<String> = emptyList(),
val selectedStEngine: String = "",
val localAiModels: List<LocalAiModel> = emptyList(),
val selectedLocalAiModelId: String = "",
): MviState
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@ package com.shifthackz.aisdv1.presentation.widget.engine

import com.shifthackz.aisdv1.core.common.extensions.EmptyLambda
import com.shifthackz.aisdv1.core.common.log.errorLog
import com.shifthackz.aisdv1.core.common.model.Quadruple
import com.shifthackz.aisdv1.core.common.model.Quintuple
import com.shifthackz.aisdv1.core.common.schedulers.SchedulersProvider
import com.shifthackz.aisdv1.core.common.schedulers.subscribeOnMainThread
import com.shifthackz.aisdv1.core.viewmodel.MviRxViewModel
import com.shifthackz.aisdv1.domain.entity.LocalAiModel
import com.shifthackz.aisdv1.domain.entity.ServerSource
import com.shifthackz.aisdv1.domain.preference.PreferenceManager
import com.shifthackz.aisdv1.domain.usecase.downloadable.ObserveLocalAiModelsUseCase
import com.shifthackz.aisdv1.domain.usecase.huggingface.FetchAndGetHuggingFaceModelsUseCase
import com.shifthackz.aisdv1.domain.usecase.sdmodel.GetStableDiffusionModelsUseCase
import com.shifthackz.aisdv1.domain.usecase.sdmodel.SelectStableDiffusionModelUseCase
Expand All @@ -23,6 +25,7 @@ class EngineSelectionViewModel(
private val getConfigurationUseCase: GetConfigurationUseCase,
private val selectStableDiffusionModelUseCase: SelectStableDiffusionModelUseCase,
private val getStableDiffusionModelsUseCase: GetStableDiffusionModelsUseCase,
observeLocalAiModelsUseCase: ObserveLocalAiModelsUseCase,
fetchAndGetStabilityAiEnginesUseCase: FetchAndGetStabilityAiEnginesUseCase,
getHuggingFaceModelsUseCase: FetchAndGetHuggingFaceModelsUseCase,
) : MviRxViewModel<EngineSelectionState, EngineSelectionIntent, EmptyEffect>() {
Expand All @@ -46,28 +49,37 @@ class EngineSelectionViewModel(
.onErrorReturn { emptyList() }
.toFlowable()

val localAiModels = observeLocalAiModelsUseCase()
.map { models -> models.filter { it.downloaded || it.id == LocalAiModel.CUSTOM.id } }
.onErrorReturn { emptyList() }

!Flowable.combineLatest(
configuration,
a1111Models,
huggingFaceModels,
stabilityAiEngines,
::Quadruple,
localAiModels,
::Quintuple,
)
.subscribeOnMainThread(schedulersProvider)
.subscribeBy(
onError = ::errorLog,
onComplete = EmptyLambda,
onNext = { (config, sdModels, hfModels, stEngines) ->
onNext = { (config, sdModels, hfModels, stEngines, localModels) ->
updateState { state ->
state.copy(
loading = false,
mode = config.source,
sdModels = sdModels.map { it.first.title },
selectedSdModel = sdModels.first { it.second }.first.title,
selectedSdModel = sdModels.firstOrNull { it.second }?.first?.title
?: state.selectedSdModel,
hfModels = hfModels.map { it.alias },
selectedHfModel = config.huggingFaceModel,
stEngines = stEngines.map { it.id },
selectedStEngine = config.stabilityAiEngineId,
localAiModels = localModels,
selectedLocalAiModelId = localModels.firstOrNull { it.id == config.localModelId }?.id
?: state.selectedLocalAiModelId
)
}
},
Expand All @@ -82,7 +94,7 @@ class EngineSelectionViewModel(
it.copy(
loading = true,
selectedSdModel = intent.value,
)
)
}
}
.andThen(getStableDiffusionModelsUseCase())
Expand All @@ -101,6 +113,8 @@ class EngineSelectionViewModel(

ServerSource.STABILITY_AI -> preferenceManager.stabilityAiEngineId = intent.value

ServerSource.LOCAL -> preferenceManager.localModelId = intent.value

else -> Unit
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ fun <T : Any> DropdownTextField(
modifier: Modifier = Modifier,
loading: Boolean = false,
label: UiText = UiText.empty,
value: T,
value: T?,
items: List<T> = emptyList(),
onItemSelected: (T) -> Unit = {},
displayDelegate: (T) -> UiText = { t -> t.toString().asUiText() },
Expand All @@ -46,7 +46,7 @@ fun <T : Any> DropdownTextField(
modifier = Modifier
.fillMaxWidth()
.menuAnchor(),
value = displayDelegate(value).asString(),
value = value?.let { displayDelegate(it).asString() } ?: "",
onValueChange = {},
readOnly = true,
label = { Text(label.asString()) },
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,8 @@ fun GenerationInputForm(
when (state.mode) {
ServerSource.AUTOMATIC1111,
ServerSource.STABILITY_AI,
ServerSource.HUGGING_FACE -> EngineSelectionComponent(
ServerSource.HUGGING_FACE,
ServerSource.LOCAL -> EngineSelectionComponent(
modifier = Modifier
.fillMaxWidth()
.padding(top = 8.dp),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import androidx.room.Query
import com.shifthackz.aisdv1.storage.db.persistent.contract.LocalModelContract
import com.shifthackz.aisdv1.storage.db.persistent.entity.LocalModelEntity
import io.reactivex.rxjava3.core.Completable
import io.reactivex.rxjava3.core.Flowable
import io.reactivex.rxjava3.core.Single

@Dao
Expand All @@ -15,6 +16,9 @@ interface LocalModelDao {
@Query("SELECT * FROM ${LocalModelContract.TABLE}")
fun query(): Single<List<LocalModelEntity>>

@Query("SELECT * FROM ${LocalModelContract.TABLE}")
fun observe(): Flowable<List<LocalModelEntity>>

@Query("SELECT * FROM ${LocalModelContract.TABLE} WHERE ${LocalModelContract.ID} = :id LIMIT 1")
fun queryById(id: String): Single<LocalModelEntity>

Expand Down