Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions core/localization/src/main/res/values/strings.xml
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@
<string name="title_txt_inversion">Textual Inversion</string>
<string name="title_txt_inversion_short">Inversion</string>
<string name="title_tag_edit">Edit tag</string>
<string name="title_select_download_source">Select source</string>

<string name="gallery_media_store_banner">You have %1$s photos saved in Download/SDAI</string>
<string name="gallery_info_field_date">Created</string>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,7 @@ internal class DownloadableModelRepositoryImpl(
private val buildInfoProvider: BuildInfoProvider,
) : DownloadableModelRepository {

override fun download(id: String) = localDataSource
.getById(id)
.flatMapObservable { model ->
remoteDataSource.download(id, model.sources.firstOrNull() ?: "")
}
override fun download(id: String, url: String) = remoteDataSource.download(id, url)

override fun delete(id: String) = localDataSource.delete(id)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -219,29 +219,14 @@ class DownloadableModelRepositoryImplTest {
.assertNotComplete()
}

@Test
fun `given attempt to download model, local data source has no such model, expected error value`() {
every {
stubLocalDataSource.getById(any())
} returns Single.error(stubException)

repository
.download("5598")
.test()
.assertNoValues()
.assertError(stubException)
.await()
.assertNotComplete()
}

@Test
fun `given attempt to download model, local data source has such model, download succeeds, expected unknown, downloading, complete values`() {
every {
stubLocalDataSource.getById(any())
} returns Single.just(mockLocalAiModel)

val stubObserver = repository
.download("5598")
.download("5598", "https://moroz.cc/stub.zip")
.test()

stubDownloadState.onNext(DownloadState.Unknown)
Expand Down Expand Up @@ -276,7 +261,7 @@ class DownloadableModelRepositoryImplTest {
} returns Single.just(mockLocalAiModel)

val stubObserver = repository
.download("5598")
.download("5598", "https://moroz.cc/stub.zip")
.test()

stubDownloadState.onNext(DownloadState.Unknown)
Expand Down Expand Up @@ -309,7 +294,7 @@ class DownloadableModelRepositoryImplTest {
} returns Observable.error(stubException)

repository
.download("5598")
.download("5598", "https://moroz.cc/stub.zip")
.test()
.assertError(stubException)
.assertNoValues()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ import com.shifthackz.aisdv1.domain.usecase.downloadable.DownloadModelUseCase
import com.shifthackz.aisdv1.domain.usecase.downloadable.DownloadModelUseCaseImpl
import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalMediaPipeModelsUseCase
import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalMediaPipeModelsUseCaseImpl
import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalModelUseCase
import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalModelUseCaseImpl
import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalOnnxModelsUseCase
import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalOnnxModelsUseCaseImpl
import com.shifthackz.aisdv1.domain.usecase.downloadable.ObserveLocalOnnxModelsUseCase
Expand Down Expand Up @@ -185,6 +187,7 @@ internal val useCasesModule = module {
factoryOf(::FetchAndGetStabilityAiEnginesUseCaseImpl) bind FetchAndGetStabilityAiEnginesUseCase::class
factoryOf(::FetchAndGetSupportersUseCaseImpl) bind FetchAndGetSupportersUseCase::class
factoryOf(::SendReportUseCaseImpl) bind SendReportUseCase::class
factoryOf(::GetLocalModelUseCaseImpl) bind GetLocalModelUseCase::class
}

internal val interActorsModule = module {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import io.reactivex.rxjava3.core.Observable
import io.reactivex.rxjava3.core.Single

interface DownloadableModelRepository {
fun download(id: String): Observable<DownloadState>
fun download(id: String, url: String): Observable<DownloadState>
fun delete(id: String): Completable
fun getAllOnnx(): Single<List<LocalAiModel>>
fun getAllMediaPipe(): Single<List<LocalAiModel>>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@ import com.shifthackz.aisdv1.domain.entity.DownloadState
import io.reactivex.rxjava3.core.Observable

interface DownloadModelUseCase {
operator fun invoke(id: String): Observable<DownloadState>
operator fun invoke(id: String, url: String): Observable<DownloadState>
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@ internal class DownloadModelUseCaseImpl(
private val downloadableModelRepository: DownloadableModelRepository,
) : DownloadModelUseCase {

override fun invoke(id: String) = downloadableModelRepository.download(id)
override fun invoke(id: String, url: String) = downloadableModelRepository.download(id, url)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package com.shifthackz.aisdv1.domain.usecase.downloadable

import com.shifthackz.aisdv1.domain.entity.LocalAiModel
import io.reactivex.rxjava3.core.Single

interface GetLocalModelUseCase {
operator fun invoke(id: String): Single<LocalAiModel>
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package com.shifthackz.aisdv1.domain.usecase.downloadable

import com.shifthackz.aisdv1.domain.datasource.DownloadableModelDataSource
import com.shifthackz.aisdv1.domain.entity.LocalAiModel
import io.reactivex.rxjava3.core.Single

internal class GetLocalModelUseCaseImpl(
private val localDataSource: DownloadableModelDataSource.Local,
) : GetLocalModelUseCase {

override fun invoke(id: String): Single<LocalAiModel> = localDataSource.getById(id)
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ class DownloadModelUseCaseImplTest {

@Before
fun initialize() {
whenever(stubRepository.download(any()))
whenever(stubRepository.download(any(), any()))
.thenReturn(stubDownloadStatus)
}

@Test
fun `given download running, then finishes successfully, expected final state is Complete`() {
val stubObserver = useCase("5598").test()
val stubObserver = useCase("5598", "https://moroz.cc/stub.zip").test()

stubDownloadStatus.onNext(DownloadState.Unknown)

Expand Down Expand Up @@ -58,7 +58,7 @@ class DownloadModelUseCaseImplTest {

@Test
fun `given download running, then fails, expected final state is Error`() {
val stubObserver = useCase("5598").test()
val stubObserver = useCase("5598", "https://moroz.cc/stub.zip").test()

stubDownloadStatus.onNext(DownloadState.Unknown)

Expand Down Expand Up @@ -87,7 +87,7 @@ class DownloadModelUseCaseImplTest {

@Test
fun `given download running, then fails, then user restarts download, then completes, expected state Error on 1st try, final state is Complete`() {
val stubObserver = useCase("5598").test()
val stubObserver = useCase("5598", "https://moroz.cc/stub.zip").test()

stubDownloadStatus.onNext(DownloadState.Unknown)

Expand Down Expand Up @@ -140,10 +140,10 @@ class DownloadModelUseCaseImplTest {

@Test
fun `given observable terminated with unexpected error, expected error value`() {
whenever(stubRepository.download(any()))
whenever(stubRepository.download(any(), any()))
.thenReturn(Observable.error(stubTerminateException))

useCase("5598")
useCase("5598", "https://moroz.cc/stub.zip")
.test()
.assertError(stubTerminateException)
.await()
Expand Down
4 changes: 2 additions & 2 deletions gradle/libs.versions.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[versions]
versionName = "0.6.7"
versionCode = "188"
versionName = "0.6.8"
versionCode = "190"
targetSdk = "34"
compileSdk = "35"
minSdk = "24"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.shifthackz.aisdv1.presentation.di

import com.shifthackz.aisdv1.presentation.activity.AiStableDiffusionViewModel
import com.shifthackz.aisdv1.presentation.modal.download.DownloadDialogViewModel
import com.shifthackz.aisdv1.presentation.modal.embedding.EmbeddingViewModel
import com.shifthackz.aisdv1.presentation.modal.extras.ExtrasViewModel
import com.shifthackz.aisdv1.presentation.modal.history.InputHistoryViewModel
Expand Down Expand Up @@ -53,6 +54,7 @@ val viewModelModule = module {
viewModelOf(::DonateViewModel)
viewModelOf(::BackgroundWorkViewModel)
viewModelOf(::LoggerViewModel)
viewModelOf(::DownloadDialogViewModel)

viewModel { parameters ->
OnBoardingViewModel(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import com.shifthackz.aisdv1.presentation.core.GenerationFormUpdateEvent
import com.shifthackz.aisdv1.presentation.core.GenerationMviIntent
import com.shifthackz.aisdv1.presentation.core.ImageToImageIntent
import com.shifthackz.aisdv1.presentation.modal.crop.CropImageModal
import com.shifthackz.aisdv1.presentation.modal.download.DownloadDialog
import com.shifthackz.aisdv1.presentation.modal.embedding.EmbeddingScreen
import com.shifthackz.aisdv1.presentation.modal.extras.ExtrasScreen
import com.shifthackz.aisdv1.presentation.modal.grid.GridBottomSheet
Expand Down Expand Up @@ -367,5 +368,14 @@ fun ModalRenderer(
}
)
}

is Modal.SelectDownloadSource -> DownloadDialog(
modelId = screenModal.modelId,
onDismissRequest = dismiss,
onDownloadSourceSelected = { url ->
processIntent(ServerSetupIntent.LocalModel.DownloadConfirm(screenModal.modelId, url))
dismiss()
}
)
}
}
Loading