From 588090b65f95067e28219bbf76a2aebaf059fd94 Mon Sep 17 00:00:00 2001 From: ShiftHackZ Date: Mon, 26 Aug 2024 00:09:46 +0300 Subject: [PATCH 01/10] Mediapipe patch 1 --- app/build.gradle.kts | 1 + app/src/main/AndroidManifest.xml | 4 + .../shifthackz/aisdv1/app/di/FeatureModule.kt | 10 +- .../aisdv1/app/di/ProvidersModule.kt | 1 + .../common/file/FileProviderDescriptor.kt | 1 + .../src/main/res/values/strings.xml | 2 + .../aisdv1/data/di/RepositoryModule.kt | 3 + .../MediaPipeGenerationRepositoryImpl.kt | 43 ++++ .../aisdv1/domain/di/DomainModule.kt | 3 + .../aisdv1/domain/entity/ServerSource.kt | 8 + .../domain/feature/mediapipe/MediaPipe.kt | 10 + .../settings/SetupConnectionInterActor.kt | 2 + .../settings/SetupConnectionInterActorImpl.kt | 2 + .../MediaPipeGenerationRepository.kt | 9 + .../generation/TextToImageUseCaseImpl.kt | 3 + .../settings/ConnectToMediaPipeUseCase.kt | 7 + .../settings/ConnectToMediaPipeUseCaseImpl.kt | 21 ++ feature/mediapipe/.gitignore | 1 + feature/mediapipe/build.gradle.kts | 25 +++ feature/mediapipe/consumer-rules.pro | 0 feature/mediapipe/proguard-rules.pro | 21 ++ .../mediapipe/src/main/AndroidManifest.xml | 4 + .../datatransport/AutoValue_Event.java | 27 +++ .../android/datatransport/Encoding.java | 18 ++ .../google/android/datatransport/Event.java | 48 +++++ .../android/datatransport/Priority.java | 10 + .../android/datatransport/Transformer.java | 5 + .../android/datatransport/Transport.java | 7 + .../datatransport/TransportFactory.java | 8 + .../TransportScheduleCallback.java | 7 + .../datatransport/cct/CCTDestination.java | 6 + .../runtime/TransportRuntime.java | 12 ++ .../runtime/backends/TransportBackend.java | 9 + .../google/common/flogger/FluentLogger.java | 9 + .../google/firebase/encoders/DataEncoder.java | 16 ++ .../com/google/firebase/encoders/Encoder.java | 11 + .../firebase/encoders/EncodingException.java | 14 ++ .../firebase/encoders/FieldDescriptor.java | 59 +++++ .../firebase/encoders/ObjectEncoder.java | 3 + .../encoders/ObjectEncoderContext.java | 56 +++++ .../firebase/encoders/ValueEncoder.java | 3 + .../encoders/ValueEncoderContext.java | 37 ++++ .../encoders/config/Configurator.java | 7 + .../encoders/config/EncoderConfig.java | 14 ++ .../encoders/json/JsonDataEncoderBuilder.java | 66 ++++++ .../json/JsonValueObjectEncoderContext.java | 201 ++++++++++++++++++ .../firebase/encoders/proto/AtProtobuf.java | 48 +++++ .../proto/LengthCountingOutputStream.java | 27 +++ .../firebase/encoders/proto/ProtoEnum.java | 7 + .../firebase/encoders/proto/Protobuf.java | 14 ++ .../proto/ProtobufDataEncoderContext.java | 151 +++++++++++++ .../encoders/proto/ProtobufEncoder.java | 69 ++++++ .../proto/ProtobufValueEncoderContext.java | 64 ++++++ .../aisdv1/feature/mediapipe/MediaPipeImpl.kt | 57 +++++ .../feature/mediapipe/di/MediaPipeModule.kt | 11 + .../mediapipe/extensions/ModelPaths.kt | 2 + .../feature/mediapipe/ExampleUnitTest.kt | 16 ++ gradle/libs.versions.toml | 2 + .../screen/settings/SettingsScreen.kt | 1 + .../screen/setup/ServerSetupViewModel.kt | 9 + .../screen/setup/steps/ConfigurationStep.kt | 2 + .../widget/engine/EngineSelectionComponent.kt | 1 + .../widget/input/GenerationInputForm.kt | 1 + .../widget/source/ServerSourceLabel.kt | 1 + settings.gradle.kts | 5 +- 65 files changed, 1316 insertions(+), 6 deletions(-) create mode 100644 data/src/main/java/com/shifthackz/aisdv1/data/repository/MediaPipeGenerationRepositoryImpl.kt create mode 100644 domain/src/main/java/com/shifthackz/aisdv1/domain/feature/mediapipe/MediaPipe.kt create mode 100644 domain/src/main/java/com/shifthackz/aisdv1/domain/repository/MediaPipeGenerationRepository.kt create mode 100644 domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToMediaPipeUseCase.kt create mode 100644 domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToMediaPipeUseCaseImpl.kt create mode 100644 feature/mediapipe/.gitignore create mode 100644 feature/mediapipe/build.gradle.kts create mode 100644 feature/mediapipe/consumer-rules.pro create mode 100644 feature/mediapipe/proguard-rules.pro create mode 100644 feature/mediapipe/src/main/AndroidManifest.xml create mode 100644 feature/mediapipe/src/main/java/com/google/android/datatransport/AutoValue_Event.java create mode 100644 feature/mediapipe/src/main/java/com/google/android/datatransport/Encoding.java create mode 100644 feature/mediapipe/src/main/java/com/google/android/datatransport/Event.java create mode 100644 feature/mediapipe/src/main/java/com/google/android/datatransport/Priority.java create mode 100644 feature/mediapipe/src/main/java/com/google/android/datatransport/Transformer.java create mode 100644 feature/mediapipe/src/main/java/com/google/android/datatransport/Transport.java create mode 100644 feature/mediapipe/src/main/java/com/google/android/datatransport/TransportFactory.java create mode 100644 feature/mediapipe/src/main/java/com/google/android/datatransport/TransportScheduleCallback.java create mode 100644 feature/mediapipe/src/main/java/com/google/android/datatransport/cct/CCTDestination.java create mode 100644 feature/mediapipe/src/main/java/com/google/android/datatransport/runtime/TransportRuntime.java create mode 100644 feature/mediapipe/src/main/java/com/google/android/datatransport/runtime/backends/TransportBackend.java create mode 100644 feature/mediapipe/src/main/java/com/google/common/flogger/FluentLogger.java create mode 100644 feature/mediapipe/src/main/java/com/google/firebase/encoders/DataEncoder.java create mode 100644 feature/mediapipe/src/main/java/com/google/firebase/encoders/Encoder.java create mode 100644 feature/mediapipe/src/main/java/com/google/firebase/encoders/EncodingException.java create mode 100644 feature/mediapipe/src/main/java/com/google/firebase/encoders/FieldDescriptor.java create mode 100644 feature/mediapipe/src/main/java/com/google/firebase/encoders/ObjectEncoder.java create mode 100644 feature/mediapipe/src/main/java/com/google/firebase/encoders/ObjectEncoderContext.java create mode 100644 feature/mediapipe/src/main/java/com/google/firebase/encoders/ValueEncoder.java create mode 100644 feature/mediapipe/src/main/java/com/google/firebase/encoders/ValueEncoderContext.java create mode 100644 feature/mediapipe/src/main/java/com/google/firebase/encoders/config/Configurator.java create mode 100644 feature/mediapipe/src/main/java/com/google/firebase/encoders/config/EncoderConfig.java create mode 100644 feature/mediapipe/src/main/java/com/google/firebase/encoders/json/JsonDataEncoderBuilder.java create mode 100644 feature/mediapipe/src/main/java/com/google/firebase/encoders/json/JsonValueObjectEncoderContext.java create mode 100644 feature/mediapipe/src/main/java/com/google/firebase/encoders/proto/AtProtobuf.java create mode 100644 feature/mediapipe/src/main/java/com/google/firebase/encoders/proto/LengthCountingOutputStream.java create mode 100644 feature/mediapipe/src/main/java/com/google/firebase/encoders/proto/ProtoEnum.java create mode 100644 feature/mediapipe/src/main/java/com/google/firebase/encoders/proto/Protobuf.java create mode 100644 feature/mediapipe/src/main/java/com/google/firebase/encoders/proto/ProtobufDataEncoderContext.java create mode 100644 feature/mediapipe/src/main/java/com/google/firebase/encoders/proto/ProtobufEncoder.java create mode 100644 feature/mediapipe/src/main/java/com/google/firebase/encoders/proto/ProtobufValueEncoderContext.java create mode 100644 feature/mediapipe/src/main/java/com/shifthackz/aisdv1/feature/mediapipe/MediaPipeImpl.kt create mode 100644 feature/mediapipe/src/main/java/com/shifthackz/aisdv1/feature/mediapipe/di/MediaPipeModule.kt create mode 100644 feature/mediapipe/src/main/java/com/shifthackz/aisdv1/feature/mediapipe/extensions/ModelPaths.kt create mode 100644 feature/mediapipe/src/test/java/com/shifthackz/aisdv1/feature/mediapipe/ExampleUnitTest.kt diff --git a/app/build.gradle.kts b/app/build.gradle.kts index 5aeb348d..cb9d95fc 100755 --- a/app/build.gradle.kts +++ b/app/build.gradle.kts @@ -85,6 +85,7 @@ dependencies { implementation(project(":domain")) implementation(project(":feature:auth")) implementation(project(":feature:diffusion")) + implementation(project(":feature:mediapipe")) implementation(project(":feature:work")) implementation(project(":data")) implementation(project(":demo")) diff --git a/app/src/main/AndroidManifest.xml b/app/src/main/AndroidManifest.xml index 67b1befd..4aac9eb0 100755 --- a/app/src/main/AndroidManifest.xml +++ b/app/src/main/AndroidManifest.xml @@ -12,6 +12,10 @@ android:theme="@style/Theme.AiSdCompose.Splash" android:usesCleartextTraffic="true"> + + + + Horde Local Diffusion (Beta) Local + Google AI MediaPipe + MediaPipe Hugging Face Inference HuggingFace Open AI diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/di/RepositoryModule.kt b/data/src/main/java/com/shifthackz/aisdv1/data/di/RepositoryModule.kt index 6e08ea30..21d8edda 100755 --- a/data/src/main/java/com/shifthackz/aisdv1/data/di/RepositoryModule.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/di/RepositoryModule.kt @@ -10,6 +10,7 @@ import com.shifthackz.aisdv1.data.repository.HuggingFaceGenerationRepositoryImpl import com.shifthackz.aisdv1.data.repository.HuggingFaceModelsRepositoryImpl import com.shifthackz.aisdv1.data.repository.LocalDiffusionGenerationRepositoryImpl import com.shifthackz.aisdv1.data.repository.LorasRepositoryImpl +import com.shifthackz.aisdv1.data.repository.MediaPipeGenerationRepositoryImpl import com.shifthackz.aisdv1.data.repository.OpenAiGenerationRepositoryImpl import com.shifthackz.aisdv1.data.repository.RandomImageRepositoryImpl import com.shifthackz.aisdv1.data.repository.ServerConfigurationRepositoryImpl @@ -33,6 +34,7 @@ import com.shifthackz.aisdv1.domain.repository.HuggingFaceGenerationRepository import com.shifthackz.aisdv1.domain.repository.HuggingFaceModelsRepository import com.shifthackz.aisdv1.domain.repository.LocalDiffusionGenerationRepository import com.shifthackz.aisdv1.domain.repository.LorasRepository +import com.shifthackz.aisdv1.domain.repository.MediaPipeGenerationRepository import com.shifthackz.aisdv1.domain.repository.OpenAiGenerationRepository import com.shifthackz.aisdv1.domain.repository.RandomImageRepository import com.shifthackz.aisdv1.domain.repository.ServerConfigurationRepository @@ -63,6 +65,7 @@ val repositoryModule = module { singleOf(::TemporaryGenerationResultRepositoryImpl) bind TemporaryGenerationResultRepository::class factoryOf(::LocalDiffusionGenerationRepositoryImpl) bind LocalDiffusionGenerationRepository::class + factoryOf(::MediaPipeGenerationRepositoryImpl) bind MediaPipeGenerationRepository::class factoryOf(::HordeGenerationRepositoryImpl) bind HordeGenerationRepository::class factoryOf(::HuggingFaceGenerationRepositoryImpl) bind HuggingFaceGenerationRepository::class factoryOf(::OpenAiGenerationRepositoryImpl) bind OpenAiGenerationRepository::class diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/repository/MediaPipeGenerationRepositoryImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/repository/MediaPipeGenerationRepositoryImpl.kt new file mode 100644 index 00000000..4bac9dcc --- /dev/null +++ b/data/src/main/java/com/shifthackz/aisdv1/data/repository/MediaPipeGenerationRepositoryImpl.kt @@ -0,0 +1,43 @@ +package com.shifthackz.aisdv1.data.repository + +import com.shifthackz.aisdv1.core.imageprocessing.Base64ToBitmapConverter +import com.shifthackz.aisdv1.core.imageprocessing.BitmapToBase64Converter +import com.shifthackz.aisdv1.data.core.CoreGenerationRepository +import com.shifthackz.aisdv1.data.mappers.mapLocalDiffusionToAiGenResult +import com.shifthackz.aisdv1.domain.datasource.GenerationResultDataSource +import com.shifthackz.aisdv1.domain.entity.AiGenerationResult +import com.shifthackz.aisdv1.domain.entity.TextToImagePayload +import com.shifthackz.aisdv1.domain.feature.mediapipe.MediaPipe +import com.shifthackz.aisdv1.domain.feature.work.BackgroundWorkObserver +import com.shifthackz.aisdv1.domain.gateway.MediaStoreGateway +import com.shifthackz.aisdv1.domain.preference.PreferenceManager +import com.shifthackz.aisdv1.domain.repository.MediaPipeGenerationRepository +import io.reactivex.rxjava3.core.Single +import io.reactivex.rxjava3.schedulers.Schedulers + +internal class MediaPipeGenerationRepositoryImpl( + mediaStoreGateway: MediaStoreGateway, + base64ToBitmapConverter: Base64ToBitmapConverter, + localDataSource: GenerationResultDataSource.Local, + backgroundWorkObserver: BackgroundWorkObserver, + preferenceManager: PreferenceManager, + private val mediaPipe: MediaPipe, + private val bitmapToBase64Converter: BitmapToBase64Converter, +) : CoreGenerationRepository( + mediaStoreGateway = mediaStoreGateway, + base64ToBitmapConverter = base64ToBitmapConverter, + localDataSource = localDataSource, + preferenceManager = preferenceManager, + backgroundWorkObserver = backgroundWorkObserver, +), MediaPipeGenerationRepository { + + override fun generateFromText(payload: TextToImagePayload): Single = mediaPipe + .process(payload) + .subscribeOn(Schedulers.computation()) + .map(BitmapToBase64Converter::Input) + .flatMap(bitmapToBase64Converter::invoke) + .map(BitmapToBase64Converter.Output::base64ImageString) + .map { base64 -> payload to base64 } + .map(Pair::mapLocalDiffusionToAiGenResult) + .flatMap(::insertGenerationResult) +} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/di/DomainModule.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/di/DomainModule.kt index bccb6ebd..535ee266 100755 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/di/DomainModule.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/di/DomainModule.kt @@ -92,6 +92,8 @@ import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToHuggingFaceUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToHuggingFaceUseCaseImpl import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToLocalDiffusionUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToLocalDiffusionUseCaseImpl +import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToMediaPipeUseCase +import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToMediaPipeUseCaseImpl import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToOpenAiUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToOpenAiUseCaseImpl import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToStabilityAiUseCase @@ -164,6 +166,7 @@ internal val useCasesModule = module { factoryOf(::InterruptGenerationUseCaseImpl) bind InterruptGenerationUseCase::class factoryOf(::ConnectToHordeUseCaseImpl) bind ConnectToHordeUseCase::class factoryOf(::ConnectToLocalDiffusionUseCaseImpl) bind ConnectToLocalDiffusionUseCase::class + factoryOf(::ConnectToMediaPipeUseCaseImpl) bind ConnectToMediaPipeUseCase::class factoryOf(::ConnectToA1111UseCaseImpl) bind ConnectToA1111UseCase::class factoryOf(::ConnectToSwarmUiUseCaseImpl) bind ConnectToSwarmUiUseCase::class factoryOf(::ConnectToHuggingFaceUseCaseImpl) bind ConnectToHuggingFaceUseCase::class diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/ServerSource.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/ServerSource.kt index 4eeca344..3ebf9ac9 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/ServerSource.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/ServerSource.kt @@ -69,6 +69,14 @@ enum class ServerSource( FeatureTag.Txt2Img, FeatureTag.MultipleModels, ), + ), + LOCAL_GOOGLE_MEDIA_PIPE( + key = "local_google_media_pipe", + featureTags = setOf( + FeatureTag.Offline, + FeatureTag.Txt2Img, + FeatureTag.MultipleModels, + ), ); companion object { diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/feature/mediapipe/MediaPipe.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/feature/mediapipe/MediaPipe.kt new file mode 100644 index 00000000..f2275b25 --- /dev/null +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/feature/mediapipe/MediaPipe.kt @@ -0,0 +1,10 @@ +package com.shifthackz.aisdv1.domain.feature.mediapipe + +import android.graphics.Bitmap +import com.shifthackz.aisdv1.domain.entity.TextToImagePayload +import io.reactivex.rxjava3.core.Single + +interface MediaPipe { + fun process(payload: TextToImagePayload): Single + +} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/interactor/settings/SetupConnectionInterActor.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/interactor/settings/SetupConnectionInterActor.kt index 4c42f81c..5959a035 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/interactor/settings/SetupConnectionInterActor.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/interactor/settings/SetupConnectionInterActor.kt @@ -4,6 +4,7 @@ import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToA1111UseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToHordeUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToHuggingFaceUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToLocalDiffusionUseCase +import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToMediaPipeUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToOpenAiUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToStabilityAiUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToSwarmUiUseCase @@ -11,6 +12,7 @@ import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToSwarmUiUseCase interface SetupConnectionInterActor { val connectToHorde: ConnectToHordeUseCase val connectToLocal: ConnectToLocalDiffusionUseCase + val connectToMediaPipe: ConnectToMediaPipeUseCase val connectToA1111: ConnectToA1111UseCase val connectToHuggingFace: ConnectToHuggingFaceUseCase val connectToOpenAi: ConnectToOpenAiUseCase diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/interactor/settings/SetupConnectionInterActorImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/interactor/settings/SetupConnectionInterActorImpl.kt index 05517631..306da094 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/interactor/settings/SetupConnectionInterActorImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/interactor/settings/SetupConnectionInterActorImpl.kt @@ -4,6 +4,7 @@ import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToA1111UseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToHordeUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToHuggingFaceUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToLocalDiffusionUseCase +import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToMediaPipeUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToOpenAiUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToStabilityAiUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToSwarmUiUseCase @@ -11,6 +12,7 @@ import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToSwarmUiUseCase internal data class SetupConnectionInterActorImpl( override val connectToHorde: ConnectToHordeUseCase, override val connectToLocal: ConnectToLocalDiffusionUseCase, + override val connectToMediaPipe: ConnectToMediaPipeUseCase, override val connectToA1111: ConnectToA1111UseCase, override val connectToHuggingFace: ConnectToHuggingFaceUseCase, override val connectToOpenAi: ConnectToOpenAiUseCase, diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/MediaPipeGenerationRepository.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/MediaPipeGenerationRepository.kt new file mode 100644 index 00000000..ac6088e5 --- /dev/null +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/MediaPipeGenerationRepository.kt @@ -0,0 +1,9 @@ +package com.shifthackz.aisdv1.domain.repository + +import com.shifthackz.aisdv1.domain.entity.AiGenerationResult +import com.shifthackz.aisdv1.domain.entity.TextToImagePayload +import io.reactivex.rxjava3.core.Single + +interface MediaPipeGenerationRepository { + fun generateFromText(payload: TextToImagePayload): Single +} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImpl.kt index 45be1d30..2d38a598 100755 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImpl.kt @@ -7,6 +7,7 @@ import com.shifthackz.aisdv1.domain.preference.PreferenceManager import com.shifthackz.aisdv1.domain.repository.HordeGenerationRepository import com.shifthackz.aisdv1.domain.repository.HuggingFaceGenerationRepository import com.shifthackz.aisdv1.domain.repository.LocalDiffusionGenerationRepository +import com.shifthackz.aisdv1.domain.repository.MediaPipeGenerationRepository import com.shifthackz.aisdv1.domain.repository.OpenAiGenerationRepository import com.shifthackz.aisdv1.domain.repository.StabilityAiGenerationRepository import com.shifthackz.aisdv1.domain.repository.StableDiffusionGenerationRepository @@ -22,6 +23,7 @@ internal class TextToImageUseCaseImpl( private val stabilityAiGenerationRepository: StabilityAiGenerationRepository, private val swarmUiGenerationRepository: SwarmUiGenerationRepository, private val localDiffusionGenerationRepository: LocalDiffusionGenerationRepository, + private val mediaPipeGenerationRepository: MediaPipeGenerationRepository, private val preferenceManager: PreferenceManager, ) : TextToImageUseCase { @@ -40,5 +42,6 @@ internal class TextToImageUseCaseImpl( ServerSource.OPEN_AI -> openAiGenerationRepository.generateFromText(payload) ServerSource.STABILITY_AI -> stabilityAiGenerationRepository.generateFromText(payload) ServerSource.SWARM_UI -> swarmUiGenerationRepository.generateFromText(payload) + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> mediaPipeGenerationRepository.generateFromText(payload) } } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToMediaPipeUseCase.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToMediaPipeUseCase.kt new file mode 100644 index 00000000..a1f04a74 --- /dev/null +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToMediaPipeUseCase.kt @@ -0,0 +1,7 @@ +package com.shifthackz.aisdv1.domain.usecase.settings + +import io.reactivex.rxjava3.core.Single + +interface ConnectToMediaPipeUseCase { + operator fun invoke(): Single> +} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToMediaPipeUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToMediaPipeUseCaseImpl.kt new file mode 100644 index 00000000..cf2d2a3d --- /dev/null +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToMediaPipeUseCaseImpl.kt @@ -0,0 +1,21 @@ +package com.shifthackz.aisdv1.domain.usecase.settings + +import com.shifthackz.aisdv1.domain.entity.ServerSource +import io.reactivex.rxjava3.core.Single + +internal class ConnectToMediaPipeUseCaseImpl( + private val getConfigurationUseCase: GetConfigurationUseCase, + private val setServerConfigurationUseCase: SetServerConfigurationUseCase, +) : ConnectToMediaPipeUseCase { + + override fun invoke() = getConfigurationUseCase() + .map { originalConfiguration -> + originalConfiguration.copy( + source = ServerSource.LOCAL_GOOGLE_MEDIA_PIPE, + ) + } + .flatMapCompletable(setServerConfigurationUseCase::invoke) + .andThen(Single.just(Result.success(Unit))) + .onErrorResumeNext { t -> Single.just(Result.failure(t)) } + +} diff --git a/feature/mediapipe/.gitignore b/feature/mediapipe/.gitignore new file mode 100644 index 00000000..42afabfd --- /dev/null +++ b/feature/mediapipe/.gitignore @@ -0,0 +1 @@ +/build \ No newline at end of file diff --git a/feature/mediapipe/build.gradle.kts b/feature/mediapipe/build.gradle.kts new file mode 100644 index 00000000..9612f9b1 --- /dev/null +++ b/feature/mediapipe/build.gradle.kts @@ -0,0 +1,25 @@ +plugins { + alias(libs.plugins.generic.library) +} + +android { + namespace = "com.shifthackz.aisdv1.feature.diffusion" +} + +dependencies { + implementation(project(":core:common")) + implementation(project(":domain")) + implementation(libs.koin.core) + implementation(libs.rx.kotlin) + implementation(libs.google.mediapipe.image.generator) { + exclude(group = "com.google.firebase", module = "firebase-encoders") + exclude(group = "com.google.firebase", module = "firebase-encoders-json") + exclude(group = "com.google.firebase", module = "firebase-encoders-proto") + exclude(group = "com.google.flogger", module = "flogger") + exclude(group = "com.google.flogger", module = "flogger-system-backend") + exclude(group = "com.google.android.datatransport", module = "transport-api") + exclude(group = "com.google.android.datatransport", module = "transport-backend-cct") + exclude(group = "com.google.android.datatransport", module = "transport-runtime") + } + implementation("org.jetbrains.kotlinx:kotlinx-coroutines-core:1.8.1") +} diff --git a/feature/mediapipe/consumer-rules.pro b/feature/mediapipe/consumer-rules.pro new file mode 100644 index 00000000..e69de29b diff --git a/feature/mediapipe/proguard-rules.pro b/feature/mediapipe/proguard-rules.pro new file mode 100644 index 00000000..481bb434 --- /dev/null +++ b/feature/mediapipe/proguard-rules.pro @@ -0,0 +1,21 @@ +# Add project specific ProGuard rules here. +# You can control the set of applied configuration files using the +# proguardFiles setting in build.gradle. +# +# For more details, see +# http://developer.android.com/guide/developing/tools/proguard.html + +# If your project uses WebView with JS, uncomment the following +# and specify the fully qualified class name to the JavaScript interface +# class: +#-keepclassmembers class fqcn.of.javascript.interface.for.webview { +# public *; +#} + +# Uncomment this to preserve the line number information for +# debugging stack traces. +#-keepattributes SourceFile,LineNumberTable + +# If you keep the line number information, uncomment this to +# hide the original source file name. +#-renamesourcefileattribute SourceFile \ No newline at end of file diff --git a/feature/mediapipe/src/main/AndroidManifest.xml b/feature/mediapipe/src/main/AndroidManifest.xml new file mode 100644 index 00000000..44008a43 --- /dev/null +++ b/feature/mediapipe/src/main/AndroidManifest.xml @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/feature/mediapipe/src/main/java/com/google/android/datatransport/AutoValue_Event.java b/feature/mediapipe/src/main/java/com/google/android/datatransport/AutoValue_Event.java new file mode 100644 index 00000000..58ecd456 --- /dev/null +++ b/feature/mediapipe/src/main/java/com/google/android/datatransport/AutoValue_Event.java @@ -0,0 +1,27 @@ +package com.google.android.datatransport; + +import androidx.annotation.Nullable; + +public class AutoValue_Event extends Event { + + AutoValue_Event(@Nullable Integer code, T payload, Priority priority) { + + + } + + @Nullable + @Override + public Integer getCode() { + return 0; + } + + @Override + public T getPayload() { + return null; + } + + @Override + public Priority getPriority() { + return null; + } +} diff --git a/feature/mediapipe/src/main/java/com/google/android/datatransport/Encoding.java b/feature/mediapipe/src/main/java/com/google/android/datatransport/Encoding.java new file mode 100644 index 00000000..4cbbb057 --- /dev/null +++ b/feature/mediapipe/src/main/java/com/google/android/datatransport/Encoding.java @@ -0,0 +1,18 @@ +package com.google.android.datatransport; + +import androidx.annotation.NonNull; + +public final class Encoding { + + public static Encoding of(@NonNull String name) { + return new Encoding(name); + } + + public String getName() { + return ""; + } + + private Encoding(@NonNull String name) { + + } +} diff --git a/feature/mediapipe/src/main/java/com/google/android/datatransport/Event.java b/feature/mediapipe/src/main/java/com/google/android/datatransport/Event.java new file mode 100644 index 00000000..5cbd01c0 --- /dev/null +++ b/feature/mediapipe/src/main/java/com/google/android/datatransport/Event.java @@ -0,0 +1,48 @@ +package com.google.android.datatransport; + + +import androidx.annotation.Nullable; + +public abstract class Event { + public Event() { + } + + @Nullable + public abstract Integer getCode(); + + @Nullable + public abstract T getPayload(); + + @Nullable + public abstract Priority getPriority(); + + @Nullable + public static Event ofData(int code, T payload) { + return null; + } + + @Nullable + public static Event ofData(T payload) { + return null; + } + + @Nullable + public static Event ofTelemetry(int code, T value) { + return null; + } + + @Nullable + public static Event ofTelemetry(T value) { + return null; + } + + @Nullable + public static Event ofUrgent(int code, T value) { + return null; + } + + @Nullable + public static Event ofUrgent(T value) { + return null; + } +} diff --git a/feature/mediapipe/src/main/java/com/google/android/datatransport/Priority.java b/feature/mediapipe/src/main/java/com/google/android/datatransport/Priority.java new file mode 100644 index 00000000..6bef1467 --- /dev/null +++ b/feature/mediapipe/src/main/java/com/google/android/datatransport/Priority.java @@ -0,0 +1,10 @@ +package com.google.android.datatransport; + +public enum Priority { + DEFAULT, + VERY_LOW, + HIGHEST; + + private Priority() { + } +} diff --git a/feature/mediapipe/src/main/java/com/google/android/datatransport/Transformer.java b/feature/mediapipe/src/main/java/com/google/android/datatransport/Transformer.java new file mode 100644 index 00000000..426d2986 --- /dev/null +++ b/feature/mediapipe/src/main/java/com/google/android/datatransport/Transformer.java @@ -0,0 +1,5 @@ +package com.google.android.datatransport; + +public interface Transformer { + U apply(T var1); +} diff --git a/feature/mediapipe/src/main/java/com/google/android/datatransport/Transport.java b/feature/mediapipe/src/main/java/com/google/android/datatransport/Transport.java new file mode 100644 index 00000000..5bdf02d2 --- /dev/null +++ b/feature/mediapipe/src/main/java/com/google/android/datatransport/Transport.java @@ -0,0 +1,7 @@ +package com.google.android.datatransport; + +public interface Transport { + void send(Event var1); + + void schedule(Event var1, TransportScheduleCallback var2); +} diff --git a/feature/mediapipe/src/main/java/com/google/android/datatransport/TransportFactory.java b/feature/mediapipe/src/main/java/com/google/android/datatransport/TransportFactory.java new file mode 100644 index 00000000..a092b671 --- /dev/null +++ b/feature/mediapipe/src/main/java/com/google/android/datatransport/TransportFactory.java @@ -0,0 +1,8 @@ +package com.google.android.datatransport; + +public interface TransportFactory { + @Deprecated + Transport getTransport(String var1, Class var2, Transformer var3); + + Transport getTransport(String var1, Class var2, Encoding var3, Transformer var4); +} diff --git a/feature/mediapipe/src/main/java/com/google/android/datatransport/TransportScheduleCallback.java b/feature/mediapipe/src/main/java/com/google/android/datatransport/TransportScheduleCallback.java new file mode 100644 index 00000000..8a06f9aa --- /dev/null +++ b/feature/mediapipe/src/main/java/com/google/android/datatransport/TransportScheduleCallback.java @@ -0,0 +1,7 @@ +package com.google.android.datatransport; + +import androidx.annotation.Nullable; + +public interface TransportScheduleCallback { + void onSchedule(@Nullable Exception var1); +} diff --git a/feature/mediapipe/src/main/java/com/google/android/datatransport/cct/CCTDestination.java b/feature/mediapipe/src/main/java/com/google/android/datatransport/cct/CCTDestination.java new file mode 100644 index 00000000..6e5fd45a --- /dev/null +++ b/feature/mediapipe/src/main/java/com/google/android/datatransport/cct/CCTDestination.java @@ -0,0 +1,6 @@ +package com.google.android.datatransport.cct; + +public class CCTDestination { + + public static CCTDestination INSTANCE = new CCTDestination(); +} diff --git a/feature/mediapipe/src/main/java/com/google/android/datatransport/runtime/TransportRuntime.java b/feature/mediapipe/src/main/java/com/google/android/datatransport/runtime/TransportRuntime.java new file mode 100644 index 00000000..2c15ecf9 --- /dev/null +++ b/feature/mediapipe/src/main/java/com/google/android/datatransport/runtime/TransportRuntime.java @@ -0,0 +1,12 @@ +package com.google.android.datatransport.runtime; + +import android.content.Context; + +public class TransportRuntime { + + public static void initialize(Context applicationContext) {} + + public static TransportRuntime getInstance() { + return new TransportRuntime(); + } +} diff --git a/feature/mediapipe/src/main/java/com/google/android/datatransport/runtime/backends/TransportBackend.java b/feature/mediapipe/src/main/java/com/google/android/datatransport/runtime/backends/TransportBackend.java new file mode 100644 index 00000000..e2e45b07 --- /dev/null +++ b/feature/mediapipe/src/main/java/com/google/android/datatransport/runtime/backends/TransportBackend.java @@ -0,0 +1,9 @@ +//package com.google.android.datatransport.runtime.backends; +// +//import com.google.android.datatransport.runtime.EventInternal; +// +//public interface TransportBackend { +// EventInternal decorate(EventInternal var1); +// +// BackendResponse send(BackendRequest var1); +//} diff --git a/feature/mediapipe/src/main/java/com/google/common/flogger/FluentLogger.java b/feature/mediapipe/src/main/java/com/google/common/flogger/FluentLogger.java new file mode 100644 index 00000000..b1b88f39 --- /dev/null +++ b/feature/mediapipe/src/main/java/com/google/common/flogger/FluentLogger.java @@ -0,0 +1,9 @@ +package com.google.common.flogger; + +public class FluentLogger { + + public static FluentLogger forEnclosingClass() { + return new FluentLogger(); + } +} +//com/google/firebase/encoders/json/JsonDataEncoderBuilder; \ No newline at end of file diff --git a/feature/mediapipe/src/main/java/com/google/firebase/encoders/DataEncoder.java b/feature/mediapipe/src/main/java/com/google/firebase/encoders/DataEncoder.java new file mode 100644 index 00000000..2cab88ee --- /dev/null +++ b/feature/mediapipe/src/main/java/com/google/firebase/encoders/DataEncoder.java @@ -0,0 +1,16 @@ +package com.google.firebase.encoders; + +import androidx.annotation.NonNull; + +import java.io.IOException; +import java.io.Writer; + +public interface DataEncoder { + + /** Encodes {@code obj} into {@code writer}. */ + void encode(@NonNull Object obj, @NonNull Writer writer) throws IOException; + + /** Returns the string-encoded representation of {@code obj}. */ + @NonNull + String encode(@NonNull Object obj); +} diff --git a/feature/mediapipe/src/main/java/com/google/firebase/encoders/Encoder.java b/feature/mediapipe/src/main/java/com/google/firebase/encoders/Encoder.java new file mode 100644 index 00000000..25568ade --- /dev/null +++ b/feature/mediapipe/src/main/java/com/google/firebase/encoders/Encoder.java @@ -0,0 +1,11 @@ +package com.google.firebase.encoders; + +import androidx.annotation.NonNull; + +import java.io.IOException; + +interface Encoder { + + /** Encode {@code obj} using {@code TContext}. */ + void encode(@NonNull TValue obj, @NonNull TContext context) throws IOException; +} diff --git a/feature/mediapipe/src/main/java/com/google/firebase/encoders/EncodingException.java b/feature/mediapipe/src/main/java/com/google/firebase/encoders/EncodingException.java new file mode 100644 index 00000000..9ad3c896 --- /dev/null +++ b/feature/mediapipe/src/main/java/com/google/firebase/encoders/EncodingException.java @@ -0,0 +1,14 @@ +package com.google.firebase.encoders; + +import androidx.annotation.NonNull; + +public final class EncodingException extends RuntimeException { + + public EncodingException(@NonNull String message) { + super(message); + } + + public EncodingException(@NonNull String message, @NonNull Exception cause) { + super(message, cause); + } +} diff --git a/feature/mediapipe/src/main/java/com/google/firebase/encoders/FieldDescriptor.java b/feature/mediapipe/src/main/java/com/google/firebase/encoders/FieldDescriptor.java new file mode 100644 index 00000000..1bad9234 --- /dev/null +++ b/feature/mediapipe/src/main/java/com/google/firebase/encoders/FieldDescriptor.java @@ -0,0 +1,59 @@ +package com.google.firebase.encoders; + +import androidx.annotation.NonNull; +import androidx.annotation.Nullable; + +import java.lang.annotation.Annotation; +import java.util.Collections; +import java.util.Map; + +public final class FieldDescriptor { + + private FieldDescriptor(String name, Map, Object> properties) { + } + + /** Name of the field. */ + @NonNull + public String getName() { + return ""; + } + + /** + * Provides access to extra properties of the field. + * + * @return {@code T} annotation if present, null otherwise. + */ + @Nullable + @SuppressWarnings("unchecked") + public T getProperty(@NonNull Class type) { + return null; + } + + @NonNull + public static FieldDescriptor of(@NonNull String name) { + return new FieldDescriptor(name, Collections.emptyMap()); + } + + @NonNull + public static Builder builder(@NonNull String name) { + return new Builder(name); + } + + public static final class Builder { + + + Builder(String name) { + + } + + @NonNull + public Builder withProperty(@NonNull T value) { + return this; + } + + @NonNull + public FieldDescriptor build() { + return new FieldDescriptor("", Collections.emptyMap()); + } + } +} diff --git a/feature/mediapipe/src/main/java/com/google/firebase/encoders/ObjectEncoder.java b/feature/mediapipe/src/main/java/com/google/firebase/encoders/ObjectEncoder.java new file mode 100644 index 00000000..12271efc --- /dev/null +++ b/feature/mediapipe/src/main/java/com/google/firebase/encoders/ObjectEncoder.java @@ -0,0 +1,3 @@ +package com.google.firebase.encoders; + +public interface ObjectEncoder extends Encoder {} diff --git a/feature/mediapipe/src/main/java/com/google/firebase/encoders/ObjectEncoderContext.java b/feature/mediapipe/src/main/java/com/google/firebase/encoders/ObjectEncoderContext.java new file mode 100644 index 00000000..77936279 --- /dev/null +++ b/feature/mediapipe/src/main/java/com/google/firebase/encoders/ObjectEncoderContext.java @@ -0,0 +1,56 @@ +package com.google.firebase.encoders; + +import androidx.annotation.NonNull; +import androidx.annotation.Nullable; + +import java.io.IOException; + +public interface ObjectEncoderContext { + + @Deprecated + @NonNull + ObjectEncoderContext add(@NonNull String name, @Nullable Object obj) throws IOException; + + @Deprecated + @NonNull + ObjectEncoderContext add(@NonNull String name, double value) throws IOException; + + @Deprecated + @NonNull + ObjectEncoderContext add(@NonNull String name, int value) throws IOException; + + @Deprecated + @NonNull + ObjectEncoderContext add(@NonNull String name, long value) throws IOException; + + @Deprecated + @NonNull + ObjectEncoderContext add(@NonNull String name, boolean value) throws IOException; + + @NonNull + ObjectEncoderContext add(@NonNull FieldDescriptor field, @Nullable Object obj) throws IOException; + + @NonNull + ObjectEncoderContext add(@NonNull FieldDescriptor field, float value) throws IOException; + + @NonNull + ObjectEncoderContext add(@NonNull FieldDescriptor field, double value) throws IOException; + + @NonNull + ObjectEncoderContext add(@NonNull FieldDescriptor field, int value) throws IOException; + + @NonNull + ObjectEncoderContext add(@NonNull FieldDescriptor field, long value) throws IOException; + + @NonNull + ObjectEncoderContext add(@NonNull FieldDescriptor field, boolean value) throws IOException; + + @NonNull + ObjectEncoderContext inline(@Nullable Object value) throws IOException; + + @NonNull + ObjectEncoderContext nested(@NonNull String name) throws IOException; + + @NonNull + ObjectEncoderContext nested(@NonNull FieldDescriptor field) throws IOException; +} diff --git a/feature/mediapipe/src/main/java/com/google/firebase/encoders/ValueEncoder.java b/feature/mediapipe/src/main/java/com/google/firebase/encoders/ValueEncoder.java new file mode 100644 index 00000000..48ae1168 --- /dev/null +++ b/feature/mediapipe/src/main/java/com/google/firebase/encoders/ValueEncoder.java @@ -0,0 +1,3 @@ +package com.google.firebase.encoders; + +public interface ValueEncoder extends Encoder {} diff --git a/feature/mediapipe/src/main/java/com/google/firebase/encoders/ValueEncoderContext.java b/feature/mediapipe/src/main/java/com/google/firebase/encoders/ValueEncoderContext.java new file mode 100644 index 00000000..85a72c4d --- /dev/null +++ b/feature/mediapipe/src/main/java/com/google/firebase/encoders/ValueEncoderContext.java @@ -0,0 +1,37 @@ +package com.google.firebase.encoders; + +import androidx.annotation.NonNull; +import androidx.annotation.Nullable; + +import java.io.IOException; + +public interface ValueEncoderContext { + + /** Adds {@code value} as a primitive encoded value. */ + @NonNull + ValueEncoderContext add(@Nullable String value) throws IOException; + + /** Adds {@code value} as a primitive encoded value. */ + @NonNull + ValueEncoderContext add(float value) throws IOException; + + /** Adds {@code value} as a primitive encoded value. */ + @NonNull + ValueEncoderContext add(double value) throws IOException; + + /** Adds {@code value} as a primitive encoded value. */ + @NonNull + ValueEncoderContext add(int value) throws IOException; + + /** Adds {@code value} as a primitive encoded value. */ + @NonNull + ValueEncoderContext add(long value) throws IOException; + + /** Adds {@code value} as a primitive encoded value. */ + @NonNull + ValueEncoderContext add(boolean value) throws IOException; + + /** Adds {@code value} as a encoded array of bytes. */ + @NonNull + ValueEncoderContext add(@NonNull byte[] bytes) throws IOException; +} diff --git a/feature/mediapipe/src/main/java/com/google/firebase/encoders/config/Configurator.java b/feature/mediapipe/src/main/java/com/google/firebase/encoders/config/Configurator.java new file mode 100644 index 00000000..6332c0fe --- /dev/null +++ b/feature/mediapipe/src/main/java/com/google/firebase/encoders/config/Configurator.java @@ -0,0 +1,7 @@ +package com.google.firebase.encoders.config; + +import androidx.annotation.NonNull; + +public interface Configurator { + void configure(@NonNull EncoderConfig configuration); +} diff --git a/feature/mediapipe/src/main/java/com/google/firebase/encoders/config/EncoderConfig.java b/feature/mediapipe/src/main/java/com/google/firebase/encoders/config/EncoderConfig.java new file mode 100644 index 00000000..308e28ba --- /dev/null +++ b/feature/mediapipe/src/main/java/com/google/firebase/encoders/config/EncoderConfig.java @@ -0,0 +1,14 @@ +package com.google.firebase.encoders.config; + +import androidx.annotation.NonNull; + +import com.google.firebase.encoders.ObjectEncoder; +import com.google.firebase.encoders.ValueEncoder; + +public interface EncoderConfig> { + @NonNull + T registerEncoder(@NonNull Class type, @NonNull ObjectEncoder encoder); + + @NonNull + T registerEncoder(@NonNull Class type, @NonNull ValueEncoder encoder); +} diff --git a/feature/mediapipe/src/main/java/com/google/firebase/encoders/json/JsonDataEncoderBuilder.java b/feature/mediapipe/src/main/java/com/google/firebase/encoders/json/JsonDataEncoderBuilder.java new file mode 100644 index 00000000..2f6ebf5a --- /dev/null +++ b/feature/mediapipe/src/main/java/com/google/firebase/encoders/json/JsonDataEncoderBuilder.java @@ -0,0 +1,66 @@ +package com.google.firebase.encoders.json; + +import androidx.annotation.NonNull; + +import com.google.firebase.encoders.DataEncoder; +import com.google.firebase.encoders.ObjectEncoder; +import com.google.firebase.encoders.ValueEncoder; +import com.google.firebase.encoders.config.Configurator; +import com.google.firebase.encoders.config.EncoderConfig; + +import java.io.IOException; +import java.io.Writer; + +public class JsonDataEncoderBuilder implements EncoderConfig { + + public JsonDataEncoderBuilder() { + } + + @NonNull + @Override + public JsonDataEncoderBuilder registerEncoder( + @NonNull Class clazz, @NonNull ObjectEncoder objectEncoder) { + + return this; + } + + @NonNull + @Override + public JsonDataEncoderBuilder registerEncoder( + @NonNull Class clazz, @NonNull ValueEncoder encoder) { + + return this; + } + + /** Encoder used if no encoders are found among explicitly registered ones. */ + @NonNull + public JsonDataEncoderBuilder registerFallbackEncoder( + @NonNull ObjectEncoder fallbackEncoder) { + return this; + } + + @NonNull + public JsonDataEncoderBuilder configureWith(@NonNull Configurator config) { + return this; + } + + @NonNull + public JsonDataEncoderBuilder ignoreNullValues(boolean ignore) { + return this; + } + + @NonNull + public DataEncoder build() { + return new DataEncoder() { + @Override + public void encode(@NonNull Object o, @NonNull Writer writer) throws IOException { + + } + + @Override + public String encode(@NonNull Object o) { + return ""; + } + }; + } +} diff --git a/feature/mediapipe/src/main/java/com/google/firebase/encoders/json/JsonValueObjectEncoderContext.java b/feature/mediapipe/src/main/java/com/google/firebase/encoders/json/JsonValueObjectEncoderContext.java new file mode 100644 index 00000000..1b697bf8 --- /dev/null +++ b/feature/mediapipe/src/main/java/com/google/firebase/encoders/json/JsonValueObjectEncoderContext.java @@ -0,0 +1,201 @@ +package com.google.firebase.encoders.json; + +import androidx.annotation.NonNull; +import androidx.annotation.Nullable; + +import com.google.firebase.encoders.EncodingException; +import com.google.firebase.encoders.FieldDescriptor; +import com.google.firebase.encoders.ObjectEncoder; +import com.google.firebase.encoders.ObjectEncoderContext; +import com.google.firebase.encoders.ValueEncoder; +import com.google.firebase.encoders.ValueEncoderContext; + +import java.io.IOException; +import java.io.Writer; +import java.util.Map; + +public class JsonValueObjectEncoderContext implements ObjectEncoderContext, ValueEncoderContext { + + JsonValueObjectEncoderContext( + @NonNull Writer writer, + @NonNull Map, ObjectEncoder> objectEncoders, + @NonNull Map, ValueEncoder> valueEncoders, + ObjectEncoder fallbackEncoder, + boolean ignoreNullValues) { + + } + + JsonValueObjectEncoderContext() {} + + + @NonNull + @Override + public JsonValueObjectEncoderContext add(@NonNull String name, @Nullable Object o) + throws IOException { + return new JsonValueObjectEncoderContext(); + } + + @NonNull + @Override + public JsonValueObjectEncoderContext add(@NonNull String name, double value) throws IOException { + return new JsonValueObjectEncoderContext(); + } + + @NonNull + @Override + public JsonValueObjectEncoderContext add(@NonNull String name, int value) throws IOException { + return new JsonValueObjectEncoderContext(); + } + + @NonNull + @Override + public JsonValueObjectEncoderContext add(@NonNull String name, long value) throws IOException { + return new JsonValueObjectEncoderContext(); + } + + @NonNull + @Override + public JsonValueObjectEncoderContext add(@NonNull String name, boolean value) throws IOException { + return new JsonValueObjectEncoderContext(); + } + + @NonNull + @Override + public ObjectEncoderContext add(@NonNull FieldDescriptor field, @Nullable Object obj) + throws IOException { + return new JsonValueObjectEncoderContext(); + } + + @NonNull + @Override + public ObjectEncoderContext add(@NonNull FieldDescriptor field, float value) throws IOException { + return new JsonValueObjectEncoderContext(); + } + + @NonNull + @Override + public ObjectEncoderContext add(@NonNull FieldDescriptor field, double value) throws IOException { + return new JsonValueObjectEncoderContext(); + } + + @NonNull + @Override + public ObjectEncoderContext add(@NonNull FieldDescriptor field, int value) throws IOException { + return new JsonValueObjectEncoderContext(); + } + + @NonNull + @Override + public ObjectEncoderContext add(@NonNull FieldDescriptor field, long value) throws IOException { + return new JsonValueObjectEncoderContext(); + } + + @NonNull + @Override + public ObjectEncoderContext add(@NonNull FieldDescriptor field, boolean value) + throws IOException { + return new JsonValueObjectEncoderContext(); + } + + @NonNull + @Override + public ObjectEncoderContext inline(@Nullable Object value) throws IOException { + return new JsonValueObjectEncoderContext(); + } + + @NonNull + @Override + public ObjectEncoderContext nested(@NonNull String name) throws IOException { + + return new JsonValueObjectEncoderContext(); + } + + @NonNull + @Override + public ObjectEncoderContext nested(@NonNull FieldDescriptor field) throws IOException { + return nested(field.getName()); + } + + @NonNull + @Override + public JsonValueObjectEncoderContext add(@Nullable String value) throws IOException { + + return this; + } + + @NonNull + @Override + public JsonValueObjectEncoderContext add(float value) throws IOException { + + + return this; + } + + @NonNull + @Override + public JsonValueObjectEncoderContext add(double value) throws IOException { + + return this; + } + + @NonNull + @Override + public JsonValueObjectEncoderContext add(int value) throws IOException { + + return this; + } + + @NonNull + @Override + public JsonValueObjectEncoderContext add(long value) throws IOException { + + return this; + } + + @NonNull + @Override + public JsonValueObjectEncoderContext add(boolean value) throws IOException { + + return this; + } + + @NonNull + @Override + public JsonValueObjectEncoderContext add(@Nullable byte[] bytes) throws IOException { + + return this; + } + + @NonNull + JsonValueObjectEncoderContext add(@Nullable Object o, boolean inline) throws IOException { + + return new JsonValueObjectEncoderContext(); + } + + JsonValueObjectEncoderContext doEncode(ObjectEncoder encoder, Object o, boolean inline) + throws IOException { + return this; + } + + private boolean cannotBeInline(Object value) { + return true; + } + + void close() throws IOException { + + } + + private void maybeUnNest() throws IOException { + + } + + private JsonValueObjectEncoderContext internalAdd(@NonNull String name, @Nullable Object o) + throws IOException, EncodingException { + return add(o, false); + } + + private JsonValueObjectEncoderContext internalAddIgnoreNullValues( + @NonNull String name, @Nullable Object o) throws IOException, EncodingException { + return add(o, false); + } +} diff --git a/feature/mediapipe/src/main/java/com/google/firebase/encoders/proto/AtProtobuf.java b/feature/mediapipe/src/main/java/com/google/firebase/encoders/proto/AtProtobuf.java new file mode 100644 index 00000000..cdf7a87f --- /dev/null +++ b/feature/mediapipe/src/main/java/com/google/firebase/encoders/proto/AtProtobuf.java @@ -0,0 +1,48 @@ +package com.google.firebase.encoders.proto; + +import java.lang.annotation.Annotation; + +public class AtProtobuf { + public AtProtobuf() { + + } + + public AtProtobuf tag(int tag) { + + return this; + } + + public AtProtobuf intEncoding(Protobuf.IntEncoding intEncoding) { + + return this; + } + + public static AtProtobuf builder() { + return new AtProtobuf(); + } + + public Protobuf build() { + return new ProtobufImpl(); + } + + private static final class ProtobufImpl implements Protobuf { + + ProtobufImpl(int tag, Protobuf.IntEncoding intEncoding) { + + } + + ProtobufImpl() {} + + public Class annotationType() { + return Protobuf.class; + } + + public int tag() { + return 0; + } + + public Protobuf.IntEncoding intEncoding() { + return IntEncoding.DEFAULT; + } + } +} diff --git a/feature/mediapipe/src/main/java/com/google/firebase/encoders/proto/LengthCountingOutputStream.java b/feature/mediapipe/src/main/java/com/google/firebase/encoders/proto/LengthCountingOutputStream.java new file mode 100644 index 00000000..0d13838f --- /dev/null +++ b/feature/mediapipe/src/main/java/com/google/firebase/encoders/proto/LengthCountingOutputStream.java @@ -0,0 +1,27 @@ +package com.google.firebase.encoders.proto; + +import androidx.annotation.NonNull; + +import java.io.OutputStream; + +public class LengthCountingOutputStream extends OutputStream { + + @Override + public void write(int b) { + + } + + @Override + public void write(byte[] b) { + + } + + @Override + public void write(@NonNull byte[] b, int off, int len) { + + } + + long getLength() { + return 0L; + } +} diff --git a/feature/mediapipe/src/main/java/com/google/firebase/encoders/proto/ProtoEnum.java b/feature/mediapipe/src/main/java/com/google/firebase/encoders/proto/ProtoEnum.java new file mode 100644 index 00000000..f3499ccf --- /dev/null +++ b/feature/mediapipe/src/main/java/com/google/firebase/encoders/proto/ProtoEnum.java @@ -0,0 +1,7 @@ +package com.google.firebase.encoders.proto; + +public interface ProtoEnum { + + /** Numeric representation of the Enum. */ + int getNumber(); +} diff --git a/feature/mediapipe/src/main/java/com/google/firebase/encoders/proto/Protobuf.java b/feature/mediapipe/src/main/java/com/google/firebase/encoders/proto/Protobuf.java new file mode 100644 index 00000000..f7bc6234 --- /dev/null +++ b/feature/mediapipe/src/main/java/com/google/firebase/encoders/proto/Protobuf.java @@ -0,0 +1,14 @@ +package com.google.firebase.encoders.proto; + +public @interface Protobuf { + int tag(); + + /** Specifies numeric field encoding. */ + IntEncoding intEncoding() default IntEncoding.DEFAULT; + + enum IntEncoding { + DEFAULT, + SIGNED, + FIXED + } +} diff --git a/feature/mediapipe/src/main/java/com/google/firebase/encoders/proto/ProtobufDataEncoderContext.java b/feature/mediapipe/src/main/java/com/google/firebase/encoders/proto/ProtobufDataEncoderContext.java new file mode 100644 index 00000000..b9a063b2 --- /dev/null +++ b/feature/mediapipe/src/main/java/com/google/firebase/encoders/proto/ProtobufDataEncoderContext.java @@ -0,0 +1,151 @@ +package com.google.firebase.encoders.proto; + +import androidx.annotation.NonNull; +import androidx.annotation.Nullable; + +import com.google.firebase.encoders.EncodingException; +import com.google.firebase.encoders.FieldDescriptor; +import com.google.firebase.encoders.ObjectEncoder; +import com.google.firebase.encoders.ObjectEncoderContext; +import com.google.firebase.encoders.ValueEncoder; + +import java.io.IOException; +import java.io.OutputStream; +import java.util.Map; + +public class ProtobufDataEncoderContext implements ObjectEncoderContext { + + ProtobufDataEncoderContext( + OutputStream output, + Map, ObjectEncoder> objectEncoders, + Map, ValueEncoder> valueEncoders, + ObjectEncoder fallbackEncoder) { + + } + + @NonNull + @Override + public ObjectEncoderContext add(@NonNull String name, @Nullable Object obj) throws IOException { + return add(FieldDescriptor.of(name), obj); + } + + @NonNull + @Override + public ObjectEncoderContext add(@NonNull String name, double value) throws IOException { + return add(FieldDescriptor.of(name), value); + } + + @NonNull + @Override + public ObjectEncoderContext add(@NonNull String name, int value) throws IOException { + return add(FieldDescriptor.of(name), value); + } + + @NonNull + @Override + public ObjectEncoderContext add(@NonNull String name, long value) throws IOException { + return add(FieldDescriptor.of(name), value); + } + + @NonNull + @Override + public ObjectEncoderContext add(@NonNull String name, boolean value) throws IOException { + return add(FieldDescriptor.of(name), value); + } + + @NonNull + @Override + public ObjectEncoderContext add(@NonNull FieldDescriptor field, @Nullable Object obj) + throws IOException { + return add(field, obj, true); + } + + ObjectEncoderContext add( + @NonNull FieldDescriptor field, @Nullable Object obj, boolean skipDefault) + throws IOException { + + return this; + } + + @NonNull + @Override + public ObjectEncoderContext add(@NonNull FieldDescriptor field, double value) throws IOException { + return add(field, value, true); + } + + ObjectEncoderContext add(@NonNull FieldDescriptor field, double value, boolean skipDefault) + throws IOException { + return this; + } + + @NonNull + @Override + public ObjectEncoderContext add(@NonNull FieldDescriptor field, float value) throws IOException { + + return add(field, value, true); + } + + ObjectEncoderContext add(@NonNull FieldDescriptor field, float value, boolean skipDefault) + throws IOException { + return this; + } + + @NonNull + @Override + public ProtobufDataEncoderContext add(@NonNull FieldDescriptor field, int value) + throws IOException { + return add(field, value, true); + } + + ProtobufDataEncoderContext add(@NonNull FieldDescriptor field, int value, boolean skipDefault) + throws IOException { + return this; + } + + @NonNull + @Override + public ProtobufDataEncoderContext add(@NonNull FieldDescriptor field, long value) + throws IOException { + return add(field, value, true); + } + + ProtobufDataEncoderContext add(@NonNull FieldDescriptor field, long value, boolean skipDefault) + throws IOException { + return this; + } + + @NonNull + @Override + public ProtobufDataEncoderContext add(@NonNull FieldDescriptor field, boolean value) + throws IOException { + return add(field, value, true); + } + + ProtobufDataEncoderContext add(@NonNull FieldDescriptor field, boolean value, boolean skipDefault) + throws IOException { + return add(field, value ? 1 : 0, skipDefault); + } + + @NonNull + @Override + public ObjectEncoderContext inline(@Nullable Object value) throws IOException { + return encode(value); + } + + ProtobufDataEncoderContext encode(@Nullable Object value) throws IOException { + throw new EncodingException("No encoder for " + value.getClass()); + } + + @NonNull + @Override + public ObjectEncoderContext nested(@NonNull String name) throws IOException { + return nested(FieldDescriptor.of(name)); + } + + @NonNull + @Override + public ObjectEncoderContext nested(@NonNull FieldDescriptor field) throws IOException { + throw new EncodingException("nested() is not implemented for protobuf encoding."); + } + +} diff --git a/feature/mediapipe/src/main/java/com/google/firebase/encoders/proto/ProtobufEncoder.java b/feature/mediapipe/src/main/java/com/google/firebase/encoders/proto/ProtobufEncoder.java new file mode 100644 index 00000000..1378abca --- /dev/null +++ b/feature/mediapipe/src/main/java/com/google/firebase/encoders/proto/ProtobufEncoder.java @@ -0,0 +1,69 @@ +package com.google.firebase.encoders.proto; + +import androidx.annotation.NonNull; + +import com.google.firebase.encoders.ObjectEncoder; +import com.google.firebase.encoders.ValueEncoder; +import com.google.firebase.encoders.config.Configurator; +import com.google.firebase.encoders.config.EncoderConfig; + +import java.io.IOException; +import java.io.OutputStream; +import java.util.Map; + +public class ProtobufEncoder { + + ProtobufEncoder( + Map, ObjectEncoder> objectEncoders, + Map, ValueEncoder> valueEncoders, + ObjectEncoder fallbackEncoder) { + } + + ProtobufEncoder() {} + + /** Encodes an arbitrary object and directly writes into the output stream. */ + public void encode(@NonNull Object value, @NonNull OutputStream outputStream) throws IOException { + } + + /** Encodes an arbitrary object and returns it as a byte array. */ + @NonNull + public byte[] encode(@NonNull Object value) { + return new byte[0]; + } + + public static Builder builder() { + return new Builder(); + } + + public static final class Builder implements EncoderConfig { + + @NonNull + @Override + public Builder registerEncoder( + @NonNull Class type, @NonNull ObjectEncoder encoder) { + return this; + } + + @NonNull + @Override + public Builder registerEncoder( + @NonNull Class type, @NonNull ValueEncoder encoder) { + return this; + } + + @NonNull + public Builder registerFallbackEncoder(@NonNull ObjectEncoder fallbackEncoder) { + return this; + } + + @NonNull + public Builder configureWith(@NonNull Configurator config) { + config.configure(this); + return this; + } + + public ProtobufEncoder build() { + return new ProtobufEncoder(); + } + } +} diff --git a/feature/mediapipe/src/main/java/com/google/firebase/encoders/proto/ProtobufValueEncoderContext.java b/feature/mediapipe/src/main/java/com/google/firebase/encoders/proto/ProtobufValueEncoderContext.java new file mode 100644 index 00000000..dc2dcb33 --- /dev/null +++ b/feature/mediapipe/src/main/java/com/google/firebase/encoders/proto/ProtobufValueEncoderContext.java @@ -0,0 +1,64 @@ +package com.google.firebase.encoders.proto; + +import androidx.annotation.NonNull; +import androidx.annotation.Nullable; + +import com.google.firebase.encoders.FieldDescriptor; +import com.google.firebase.encoders.ValueEncoderContext; + +import java.io.IOException; + +public class ProtobufValueEncoderContext implements ValueEncoderContext { + + ProtobufValueEncoderContext(ProtobufDataEncoderContext objEncoderCtx) { + + } + + ProtobufValueEncoderContext() {} + + void resetContext(FieldDescriptor field, boolean skipDefault) { + + } + + @NonNull + @Override + public ValueEncoderContext add(@Nullable String value) throws IOException { + return this; + } + + @NonNull + @Override + public ValueEncoderContext add(float value) throws IOException { + return this; + } + + @NonNull + @Override + public ValueEncoderContext add(double value) throws IOException { + return this; + } + + @NonNull + @Override + public ValueEncoderContext add(int value) throws IOException { + return this; + } + + @NonNull + @Override + public ValueEncoderContext add(long value) throws IOException { + return this; + } + + @NonNull + @Override + public ValueEncoderContext add(boolean value) throws IOException { + return this; + } + + @NonNull + @Override + public ValueEncoderContext add(@NonNull byte[] bytes) throws IOException { + return this; + } +} diff --git a/feature/mediapipe/src/main/java/com/shifthackz/aisdv1/feature/mediapipe/MediaPipeImpl.kt b/feature/mediapipe/src/main/java/com/shifthackz/aisdv1/feature/mediapipe/MediaPipeImpl.kt new file mode 100644 index 00000000..32e898fb --- /dev/null +++ b/feature/mediapipe/src/main/java/com/shifthackz/aisdv1/feature/mediapipe/MediaPipeImpl.kt @@ -0,0 +1,57 @@ +package com.shifthackz.aisdv1.feature.mediapipe + +import android.content.Context +import android.graphics.Bitmap +import com.google.mediapipe.framework.image.BitmapExtractor +import com.google.mediapipe.tasks.vision.imagegenerator.ImageGenerator +import com.google.mediapipe.tasks.vision.imagegenerator.ImageGenerator.ImageGeneratorOptions +import com.shifthackz.aisdv1.core.common.file.FileProviderDescriptor +import com.shifthackz.aisdv1.domain.entity.TextToImagePayload +import com.shifthackz.aisdv1.domain.feature.mediapipe.MediaPipe +import io.reactivex.rxjava3.core.Single + +internal class MediaPipeImpl( + private val context: Context, + private val fileProviderDescriptor: FileProviderDescriptor, +): MediaPipe { + + private var imageGenerator: ImageGenerator? = null + + override fun process(payload: TextToImagePayload): Single = Single.create { emitter -> + try { + initialize() + println("Generating...") + val result = imageGenerator?.generate( + payload.prompt, + payload.samplingSteps, + payload.seed.toIntOrNull() ?: 0, + ) + println("Extracting bitmap...") + val bitmap = BitmapExtractor.extract(result?.generatedImage()) + println("bitmap = $bitmap, ${bitmap.width}X${bitmap.height}") + close() + if (!emitter.isDisposed) emitter.onSuccess(bitmap) + } catch (e: Exception) { + close() + if (!emitter.isDisposed) emitter.onError(e) + } + } + + private fun initialize(): ImageGenerator { + val options = ImageGeneratorOptions.builder() + .setImageGeneratorModelDirectory(fileProviderDescriptor.mediaPipeDirPath) + .build() + + val generator = ImageGenerator.createFromOptions(context, options) + imageGenerator = generator + println("Initialized successfully! Path: ${fileProviderDescriptor.mediaPipeDirPath}") + return generator + } + + private fun close() = runCatching { + println("Closing...") + imageGenerator?.close() + imageGenerator = null + println("Session closed!") + } +} diff --git a/feature/mediapipe/src/main/java/com/shifthackz/aisdv1/feature/mediapipe/di/MediaPipeModule.kt b/feature/mediapipe/src/main/java/com/shifthackz/aisdv1/feature/mediapipe/di/MediaPipeModule.kt new file mode 100644 index 00000000..3327827e --- /dev/null +++ b/feature/mediapipe/src/main/java/com/shifthackz/aisdv1/feature/mediapipe/di/MediaPipeModule.kt @@ -0,0 +1,11 @@ +package com.shifthackz.aisdv1.feature.mediapipe.di + +import com.shifthackz.aisdv1.domain.feature.mediapipe.MediaPipe +import com.shifthackz.aisdv1.feature.mediapipe.MediaPipeImpl +import org.koin.core.module.dsl.factoryOf +import org.koin.dsl.bind +import org.koin.dsl.module + +val mediaPipeModule = module { + factoryOf(::MediaPipeImpl) bind MediaPipe::class +} diff --git a/feature/mediapipe/src/main/java/com/shifthackz/aisdv1/feature/mediapipe/extensions/ModelPaths.kt b/feature/mediapipe/src/main/java/com/shifthackz/aisdv1/feature/mediapipe/extensions/ModelPaths.kt new file mode 100644 index 00000000..a92380c8 --- /dev/null +++ b/feature/mediapipe/src/main/java/com/shifthackz/aisdv1/feature/mediapipe/extensions/ModelPaths.kt @@ -0,0 +1,2 @@ +package com.shifthackz.aisdv1.feature.mediapipe.extensions + diff --git a/feature/mediapipe/src/test/java/com/shifthackz/aisdv1/feature/mediapipe/ExampleUnitTest.kt b/feature/mediapipe/src/test/java/com/shifthackz/aisdv1/feature/mediapipe/ExampleUnitTest.kt new file mode 100644 index 00000000..412481c5 --- /dev/null +++ b/feature/mediapipe/src/test/java/com/shifthackz/aisdv1/feature/mediapipe/ExampleUnitTest.kt @@ -0,0 +1,16 @@ +package com.shifthackz.aisdv1.feature.mediapipe + +import org.junit.Assert.* +import org.junit.Test + +/** + * Example local unit test, which will execute on the development machine (host). + * + * See [testing documentation](http://d.android.com/tools/testing). + */ +class ExampleUnitTest { + @Test + fun addition_isCorrect() { + assertEquals(4, 2 + 2) + } +} \ No newline at end of file diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index d2f0faac..e35f09d0 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -46,6 +46,7 @@ catppuccin = "0.1.2" turbine = "1.1.0" roboelectric = "4.13" testCoroutines = "1.8.1" +mediaPipeGenerator = "0.10.14" [libraries] android-tools-build-gradle = { group = "com.android.tools.build", name = "gradle", version.ref = "agp"} @@ -77,6 +78,7 @@ androidx-work-runtime = { group = "androidx.work", name = "work-runtime", versio google-gson = { group = "com.google.code.gson", name = "gson", version.ref = "gson" } google-material = { group = "com.google.android.material", name = "material", version.ref = "material" } google-accompanist-systemuicontroller = { group = "com.google.accompanist", name = "accompanist-systemuicontroller", version.ref = "accompanistSystemUi" } +google-mediapipe-image-generator = { group = "com.google.mediapipe", name = "tasks-vision-image-generator", version.ref = "mediaPipeGenerator" } retrofit-core = { group = "com.squareup.retrofit2", name = "retrofit", version.ref = "retrofit" } retrofit-converter-gson = { group = "com.squareup.retrofit2", name = "converter-gson", version.ref = "retrofit" } retrofit-adapter-rxjava3 = { group = "com.squareup.retrofit2", name = "adapter-rxjava3", version.ref = "retrofit" } diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsScreen.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsScreen.kt index 7e50da66..bb634bf1 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsScreen.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsScreen.kt @@ -236,6 +236,7 @@ private fun ContentSettingsState( ServerSource.OPEN_AI -> LocalizationR.string.srv_type_open_ai ServerSource.STABILITY_AI -> LocalizationR.string.srv_type_stability_ai ServerSource.LOCAL -> LocalizationR.string.srv_type_local_short + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> LocalizationR.string.srv_type_media_pipe_short ServerSource.SWARM_UI -> LocalizationR.string.srv_type_swarm_ui }.asUiText(), onClick = { processIntent(SettingsIntent.NavigateConfiguration) }, diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt index 3677b6b3..3c5ee8a7 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt @@ -259,6 +259,7 @@ class ServerSetupViewModel( ServerSource.OPEN_AI -> connectToOpenAi() ServerSource.STABILITY_AI -> connectToStabilityAi() ServerSource.SWARM_UI -> connectToSwarmUi() + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> connectToMediaPipe() } .doOnSubscribe { setScreenModal(Modal.Communicating(canCancel = false)) } .subscribeOnMainThread(schedulersProvider) @@ -304,6 +305,10 @@ class ServerSetupViewModel( } } + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> { + true + } + ServerSource.HUGGING_FACE -> { val validation = stringValidator(currentState.huggingFaceApiKey) updateState { @@ -410,6 +415,10 @@ class ServerSetupViewModel( return setupConnectionInterActor.connectToLocal(localModelId) } + private fun connectToMediaPipe(): Single> { + return setupConnectionInterActor.connectToMediaPipe() + } + private fun localModelDownloadClickReducer(localModel: ServerSetupState.LocalModel) { when { // User cancels download diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/steps/ConfigurationStep.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/steps/ConfigurationStep.kt index 4d53ca89..28e602fa 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/steps/ConfigurationStep.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/steps/ConfigurationStep.kt @@ -58,6 +58,8 @@ fun ConfigurationStep( state = state, processIntent = processIntent, ) + + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> Unit } } } diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionComponent.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionComponent.kt index ae6e5173..b4c4785d 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionComponent.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionComponent.kt @@ -64,6 +64,7 @@ fun EngineSelectionComponent( displayDelegate = { it.name.asUiText() }, ) + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> Unit ServerSource.HORDE -> Unit ServerSource.OPEN_AI -> Unit } diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputForm.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputForm.kt index 1a8583d5..c1da2d53 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputForm.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputForm.kt @@ -295,6 +295,7 @@ fun GenerationInputForm( displayDelegate = { it.key.asUiText() }, ) } + else -> Unit } } diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/source/ServerSourceLabel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/source/ServerSourceLabel.kt index fdca7744..3bab42ec 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/source/ServerSourceLabel.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/source/ServerSourceLabel.kt @@ -16,6 +16,7 @@ fun ServerSource.getNameUiText(): UiText = when (this) { ServerSource.AUTOMATIC1111 -> LocalizationR.string.srv_type_own ServerSource.HORDE -> LocalizationR.string.srv_type_horde ServerSource.LOCAL -> LocalizationR.string.srv_type_local + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> LocalizationR.string.srv_type_media_pipe ServerSource.HUGGING_FACE -> LocalizationR.string.srv_type_hugging_face ServerSource.OPEN_AI -> LocalizationR.string.srv_type_open_ai ServerSource.STABILITY_AI -> LocalizationR.string.srv_type_stability_ai diff --git a/settings.gradle.kts b/settings.gradle.kts index 6ee0480b..fca65518 100755 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -1,8 +1,8 @@ pluginManagement { includeBuild("build-logic") repositories { - google() mavenCentral() + google() gradlePluginPortal() maven { url = uri("https://jitpack.io") @@ -12,8 +12,8 @@ pluginManagement { dependencyResolutionManagement { repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS) repositories { - google() mavenCentral() + google() maven { url = uri("https://jitpack.io") } @@ -35,6 +35,7 @@ val modules = listOf( ":domain", ":feature:auth", ":feature:diffusion", + ":feature:mediapipe", ":feature:work", ":network", ":presentation", From 233d89650a04e1851905acddcc19b39b093c6373 Mon Sep 17 00:00:00 2001 From: ShiftHackZ Date: Mon, 26 Aug 2024 11:50:55 +0300 Subject: [PATCH 02/10] Architecture finalization --- app/build.gradle.kts | 21 -- app/src/{dev => full}/AndroidManifest.xml | 0 .../aisdv1/app/di/ProvidersModule.kt | 6 +- build-logic/convention/build.gradle.kts | 4 + .../kotlin/ApplicationConventionPlugin.kt | 5 + .../main/kotlin/FlavorsConventionPlugin.kt | 16 ++ .../src/main/kotlin/JacocoConventionPlugin.kt | 2 +- .../shifthackz/aisdv1/buildlogic/Flavors.kt | 43 ++++ .../aisdv1/core/common/appbuild/BuildType.kt | 2 + .../src/main/res/values/strings.xml | 11 +- .../preference/PreferenceManagerImplTest.kt | 8 +- .../StabilityAiCreditsRepositoryImplTest.kt | 10 +- .../aisdv1/domain/entity/ServerSource.kt | 2 +- .../InterruptGenerationUseCaseImpl.kt | 2 +- .../generation/TextToImageUseCaseImpl.kt | 2 +- .../ConnectToLocalDiffusionUseCaseImpl.kt | 2 +- .../generation/ImageToImageUseCaseImplTest.kt | 2 +- .../InterruptGenerationUseCaseImplTest.kt | 4 +- .../generation/TextToImageUseCaseImplTest.kt | 6 +- .../splash/SplashNavigationUseCaseImplTest.kt | 4 +- feature/mediapipe/build.gradle.kts | 16 +- .../aisdv1/feature/mediapipe/MediaPipeImpl.kt | 13 ++ .../aisdv1/feature/mediapipe/MediaPipeImpl.kt | 34 +-- .../datatransport/AutoValue_Event.java | 27 --- .../android/datatransport/Encoding.java | 18 -- .../google/android/datatransport/Event.java | 48 ----- .../android/datatransport/Priority.java | 10 - .../android/datatransport/Transformer.java | 5 - .../android/datatransport/Transport.java | 7 - .../datatransport/TransportFactory.java | 8 - .../TransportScheduleCallback.java | 7 - .../datatransport/cct/CCTDestination.java | 6 - .../runtime/TransportRuntime.java | 12 -- .../runtime/backends/TransportBackend.java | 9 - .../google/common/flogger/FluentLogger.java | 9 - .../google/firebase/encoders/DataEncoder.java | 16 -- .../com/google/firebase/encoders/Encoder.java | 11 - .../firebase/encoders/EncodingException.java | 14 -- .../firebase/encoders/FieldDescriptor.java | 59 ----- .../firebase/encoders/ObjectEncoder.java | 3 - .../encoders/ObjectEncoderContext.java | 56 ----- .../firebase/encoders/ValueEncoder.java | 3 - .../encoders/ValueEncoderContext.java | 37 ---- .../encoders/config/Configurator.java | 7 - .../encoders/config/EncoderConfig.java | 14 -- .../encoders/json/JsonDataEncoderBuilder.java | 66 ------ .../json/JsonValueObjectEncoderContext.java | 201 ------------------ .../firebase/encoders/proto/AtProtobuf.java | 48 ----- .../proto/LengthCountingOutputStream.java | 27 --- .../firebase/encoders/proto/ProtoEnum.java | 7 - .../firebase/encoders/proto/Protobuf.java | 14 -- .../proto/ProtobufDataEncoderContext.java | 151 ------------- .../encoders/proto/ProtobufEncoder.java | 69 ------ .../proto/ProtobufValueEncoderContext.java | 64 ------ .../mediapipe/extensions/ModelPaths.kt | 2 - .../aisdv1/work/core/CoreGenerationWorker.kt | 2 +- gradle/libs.versions.toml | 1 + .../aisdv1/presentation/di/ViewModelModule.kt | 1 + .../screen/img2img/ImageToImageScreen.kt | 14 +- .../page/LocalDiffusionPageContent.kt | 2 +- .../screen/settings/SettingsScreen.kt | 4 +- .../screen/settings/SettingsState.kt | 4 +- .../screen/setup/ServerSetupScreen.kt | 4 +- .../screen/setup/ServerSetupViewModel.kt | 11 +- .../components/ConfigurationModeButton.kt | 6 +- .../screen/setup/steps/ConfigurationStep.kt | 2 +- .../screen/txt2img/TextToImageState.kt | 2 +- .../screen/txt2img/TextToImageViewModel.kt | 2 +- .../widget/engine/EngineSelectionComponent.kt | 2 +- .../widget/engine/EngineSelectionViewModel.kt | 2 +- .../widget/input/GenerationInputForm.kt | 58 ++--- .../widget/source/ServerSourceLabel.kt | 2 +- .../screen/setup/ServerSetupScreenTest.kt | 4 +- .../screen/setup/ServerSetupViewModelTest.kt | 6 +- .../engine/EngineSelectionViewModelTest.kt | 2 +- 75 files changed, 214 insertions(+), 1167 deletions(-) rename app/src/{dev => full}/AndroidManifest.xml (100%) create mode 100644 build-logic/convention/src/main/kotlin/FlavorsConventionPlugin.kt create mode 100644 build-logic/convention/src/main/kotlin/com/shifthackz/aisdv1/buildlogic/Flavors.kt create mode 100644 feature/mediapipe/src/foss/java/com/shifthackz/aisdv1/feature/mediapipe/MediaPipeImpl.kt rename feature/mediapipe/src/{main => full}/java/com/shifthackz/aisdv1/feature/mediapipe/MediaPipeImpl.kt (67%) delete mode 100644 feature/mediapipe/src/main/java/com/google/android/datatransport/AutoValue_Event.java delete mode 100644 feature/mediapipe/src/main/java/com/google/android/datatransport/Encoding.java delete mode 100644 feature/mediapipe/src/main/java/com/google/android/datatransport/Event.java delete mode 100644 feature/mediapipe/src/main/java/com/google/android/datatransport/Priority.java delete mode 100644 feature/mediapipe/src/main/java/com/google/android/datatransport/Transformer.java delete mode 100644 feature/mediapipe/src/main/java/com/google/android/datatransport/Transport.java delete mode 100644 feature/mediapipe/src/main/java/com/google/android/datatransport/TransportFactory.java delete mode 100644 feature/mediapipe/src/main/java/com/google/android/datatransport/TransportScheduleCallback.java delete mode 100644 feature/mediapipe/src/main/java/com/google/android/datatransport/cct/CCTDestination.java delete mode 100644 feature/mediapipe/src/main/java/com/google/android/datatransport/runtime/TransportRuntime.java delete mode 100644 feature/mediapipe/src/main/java/com/google/android/datatransport/runtime/backends/TransportBackend.java delete mode 100644 feature/mediapipe/src/main/java/com/google/common/flogger/FluentLogger.java delete mode 100644 feature/mediapipe/src/main/java/com/google/firebase/encoders/DataEncoder.java delete mode 100644 feature/mediapipe/src/main/java/com/google/firebase/encoders/Encoder.java delete mode 100644 feature/mediapipe/src/main/java/com/google/firebase/encoders/EncodingException.java delete mode 100644 feature/mediapipe/src/main/java/com/google/firebase/encoders/FieldDescriptor.java delete mode 100644 feature/mediapipe/src/main/java/com/google/firebase/encoders/ObjectEncoder.java delete mode 100644 feature/mediapipe/src/main/java/com/google/firebase/encoders/ObjectEncoderContext.java delete mode 100644 feature/mediapipe/src/main/java/com/google/firebase/encoders/ValueEncoder.java delete mode 100644 feature/mediapipe/src/main/java/com/google/firebase/encoders/ValueEncoderContext.java delete mode 100644 feature/mediapipe/src/main/java/com/google/firebase/encoders/config/Configurator.java delete mode 100644 feature/mediapipe/src/main/java/com/google/firebase/encoders/config/EncoderConfig.java delete mode 100644 feature/mediapipe/src/main/java/com/google/firebase/encoders/json/JsonDataEncoderBuilder.java delete mode 100644 feature/mediapipe/src/main/java/com/google/firebase/encoders/json/JsonValueObjectEncoderContext.java delete mode 100644 feature/mediapipe/src/main/java/com/google/firebase/encoders/proto/AtProtobuf.java delete mode 100644 feature/mediapipe/src/main/java/com/google/firebase/encoders/proto/LengthCountingOutputStream.java delete mode 100644 feature/mediapipe/src/main/java/com/google/firebase/encoders/proto/ProtoEnum.java delete mode 100644 feature/mediapipe/src/main/java/com/google/firebase/encoders/proto/Protobuf.java delete mode 100644 feature/mediapipe/src/main/java/com/google/firebase/encoders/proto/ProtobufDataEncoderContext.java delete mode 100644 feature/mediapipe/src/main/java/com/google/firebase/encoders/proto/ProtobufEncoder.java delete mode 100644 feature/mediapipe/src/main/java/com/google/firebase/encoders/proto/ProtobufValueEncoderContext.java delete mode 100644 feature/mediapipe/src/main/java/com/shifthackz/aisdv1/feature/mediapipe/extensions/ModelPaths.kt diff --git a/app/build.gradle.kts b/app/build.gradle.kts index cb9d95fc..9ad3e14a 100755 --- a/app/build.gradle.kts +++ b/app/build.gradle.kts @@ -51,27 +51,6 @@ android { println("[Signature] -> Build will be signed with signature: $alias") buildTypes.getByName("release").signingConfig = signingConfigs.getByName("release") } - - flavorDimensions += "type" - productFlavors { - create("dev") { - dimension = "type" - applicationIdSuffix = ".dev" - resValue("string", "app_name", "SDAI Dev") - buildConfigField("String", "BUILD_FLAVOR_TYPE", "\"FOSS\"") - } - create("foss") { - dimension = "type" - applicationIdSuffix = ".foss" - resValue("string", "app_name", "SDAI FOSS") - buildConfigField("String", "BUILD_FLAVOR_TYPE", "\"FOSS\"") - } - create("playstore") { - dimension = "type" - resValue("string", "app_name", "SDAI") - buildConfigField("String", "BUILD_FLAVOR_TYPE", "\"GOOGLE_PLAY\"") - } - } } dependencies { diff --git a/app/src/dev/AndroidManifest.xml b/app/src/full/AndroidManifest.xml similarity index 100% rename from app/src/dev/AndroidManifest.xml rename to app/src/full/AndroidManifest.xml diff --git a/app/src/main/java/com/shifthackz/aisdv1/app/di/ProvidersModule.kt b/app/src/main/java/com/shifthackz/aisdv1/app/di/ProvidersModule.kt index 81c995de..4aa4ad84 100755 --- a/app/src/main/java/com/shifthackz/aisdv1/app/di/ProvidersModule.kt +++ b/app/src/main/java/com/shifthackz/aisdv1/app/di/ProvidersModule.kt @@ -130,7 +130,11 @@ val providersModule = module { append("$version") if (BuildConfig.DEBUG) append("-dev") append(" ($buildNumber)") - if (type == BuildType.FOSS) append(" FOSS") + when (type) { + BuildType.FULL -> append(" FULL") + BuildType.FOSS -> append(" FOSS") + BuildType.PLAY -> Unit + } } } } diff --git a/build-logic/convention/build.gradle.kts b/build-logic/convention/build.gradle.kts index 75b16dee..2d730493 100644 --- a/build-logic/convention/build.gradle.kts +++ b/build-logic/convention/build.gradle.kts @@ -36,6 +36,10 @@ gradlePlugin { id = "generic.application" implementationClass = "ApplicationConventionPlugin" } + register("Flavors") { + id = "generic.flavors" + implementationClass = "FlavorsConventionPlugin" + } register("BaselineProFm") { id = "generic.baseline.profm" implementationClass = "BaselineProFmConventionPlugin" diff --git a/build-logic/convention/src/main/kotlin/ApplicationConventionPlugin.kt b/build-logic/convention/src/main/kotlin/ApplicationConventionPlugin.kt index 2b126264..34cbad9a 100644 --- a/build-logic/convention/src/main/kotlin/ApplicationConventionPlugin.kt +++ b/build-logic/convention/src/main/kotlin/ApplicationConventionPlugin.kt @@ -1,6 +1,8 @@ import com.android.build.api.dsl.ApplicationExtension +import com.android.build.gradle.BaseExtension import com.shifthackz.aisdv1.buildlogic.configureApplication import com.shifthackz.aisdv1.buildlogic.configureCompose +import com.shifthackz.aisdv1.buildlogic.configureFlavors import com.shifthackz.aisdv1.buildlogic.libs import org.gradle.api.Plugin import org.gradle.api.Project @@ -22,6 +24,9 @@ class ApplicationConventionPlugin : Plugin { configureCompose(this) defaultConfig.targetSdk = libs.findVersion("targetSdk").get().toString().toInt() } + extensions.configure { + configureFlavors(this) + } } } } diff --git a/build-logic/convention/src/main/kotlin/FlavorsConventionPlugin.kt b/build-logic/convention/src/main/kotlin/FlavorsConventionPlugin.kt new file mode 100644 index 00000000..743465f1 --- /dev/null +++ b/build-logic/convention/src/main/kotlin/FlavorsConventionPlugin.kt @@ -0,0 +1,16 @@ +import com.android.build.gradle.LibraryExtension +import com.shifthackz.aisdv1.buildlogic.configureFlavorsCommon +import org.gradle.api.Plugin +import org.gradle.api.Project +import org.gradle.kotlin.dsl.configure + +class FlavorsConventionPlugin : Plugin { + + override fun apply(target: Project) { + with(target) { + extensions.configure { + configureFlavorsCommon(this) + } + } + } +} diff --git a/build-logic/convention/src/main/kotlin/JacocoConventionPlugin.kt b/build-logic/convention/src/main/kotlin/JacocoConventionPlugin.kt index 5bc8f1b3..6e83a3fc 100644 --- a/build-logic/convention/src/main/kotlin/JacocoConventionPlugin.kt +++ b/build-logic/convention/src/main/kotlin/JacocoConventionPlugin.kt @@ -7,7 +7,7 @@ import org.gradle.kotlin.dsl.configure import org.gradle.kotlin.dsl.withType import org.gradle.testing.jacoco.plugins.JacocoTaskExtension -class JacocoConventionPlugin : Plugin { +class JacocoConventionPlugin : Plugin { override fun apply(target: Project) { with(target) { diff --git a/build-logic/convention/src/main/kotlin/com/shifthackz/aisdv1/buildlogic/Flavors.kt b/build-logic/convention/src/main/kotlin/com/shifthackz/aisdv1/buildlogic/Flavors.kt new file mode 100644 index 00000000..6357b9ff --- /dev/null +++ b/build-logic/convention/src/main/kotlin/com/shifthackz/aisdv1/buildlogic/Flavors.kt @@ -0,0 +1,43 @@ +package com.shifthackz.aisdv1.buildlogic + +import com.android.build.api.dsl.CommonExtension +import com.android.build.gradle.BaseExtension +import org.gradle.api.Project + +internal fun Project.configureFlavors( + commonExtension: BaseExtension, +) { + commonExtension.apply { + flavorDimensions("type") + productFlavors.create("full") { + dimension = "type" + applicationIdSuffix = ".full" + resValue("string", "app_name", "SDAI Full") + buildConfigField("String", "BUILD_FLAVOR_TYPE", "\"FULL\"") + + } + productFlavors.create("foss") { + dimension = "type" + applicationIdSuffix = ".foss" + resValue("string", "app_name", "SDAI FOSS") + buildConfigField("String", "BUILD_FLAVOR_TYPE", "\"FOSS\"") + + } + productFlavors.create("playstore") { + dimension = "type" + resValue("string", "app_name", "SDAI") + buildConfigField("String", "BUILD_FLAVOR_TYPE", "\"GOOGLE_PLAY\"") + } + } +} + +internal fun Project.configureFlavorsCommon( + commonExtension: CommonExtension<*, *, *, *, *, *>, +) { + commonExtension.apply { + flavorDimensions += listOf("type") + productFlavors.create("full") { dimension = "type" } + productFlavors.create("foss") { dimension = "type" } + productFlavors.create("playstore") { dimension = "type" } + } +} diff --git a/core/common/src/main/java/com/shifthackz/aisdv1/core/common/appbuild/BuildType.kt b/core/common/src/main/java/com/shifthackz/aisdv1/core/common/appbuild/BuildType.kt index d431d0f8..2f4c4471 100644 --- a/core/common/src/main/java/com/shifthackz/aisdv1/core/common/appbuild/BuildType.kt +++ b/core/common/src/main/java/com/shifthackz/aisdv1/core/common/appbuild/BuildType.kt @@ -1,11 +1,13 @@ package com.shifthackz.aisdv1.core.common.appbuild enum class BuildType { + FULL, FOSS, PLAY; companion object { fun fromBuildConfig(input: String) = when (input) { + "FULL" -> FULL "FOSS" -> FOSS else -> PLAY } diff --git a/core/localization/src/main/res/values/strings.xml b/core/localization/src/main/res/values/strings.xml index e6aca262..a4f11e11 100755 --- a/core/localization/src/main/res/values/strings.xml +++ b/core/localization/src/main/res/values/strings.xml @@ -70,9 +70,9 @@ A1111 Horde AI Cloud Horde - Local Diffusion (Beta) - Local - Google AI MediaPipe + Local Diffusion ONNX (Beta) + Local ONNX + Local Google AI MediaPipe MediaPipe Hugging Face Inference HuggingFace @@ -153,9 +153,12 @@ A Modular Stable Diffusion Web-User-Interface, with an emphasis on making tools easily accessible, high performance, and extensibility. Local Diffusion - This configuration allows to run Stable Diffusion AI generations on your phone, with no need to connect to remote server/cloud. + + This configuration uses Microsoft ONNX runtime and allows to run Stable Diffusion AI generations on your phone, with no need to connect to remote server/cloud. Warning! Local Diffusion functionality is in beta-test. Don\'t expect for high quality images using local mode. \n\nThis implementation may not work well on non-powerful phones. Generation performance and speed depends on your phone resources (CPU, RAM) and the size of generated image (the smaller the image size, the faster the generation). + This configuration uses Google AI MediaPipe and allows to run Stable Diffusion AI generations on your phone, with no need to connect to remote server/cloud. + Web UI Txt2Img diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImplTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImplTest.kt index 0535d7f8..b1d953ae 100644 --- a/data/src/test/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImplTest.kt +++ b/data/src/test/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImplTest.kt @@ -223,17 +223,17 @@ class PreferenceManagerImplTest { Assert.assertEquals(ServerSource.AUTOMATIC1111, preferenceManager.source) whenever(stubPreference.getString(eq(KEY_SERVER_SOURCE), any())) - .thenReturn(ServerSource.LOCAL.key) + .thenReturn(ServerSource.LOCAL_MICROSOFT_ONNX.key) - preferenceManager.source = ServerSource.LOCAL + preferenceManager.source = ServerSource.LOCAL_MICROSOFT_ONNX - Assert.assertEquals(ServerSource.LOCAL, preferenceManager.source) + Assert.assertEquals(ServerSource.LOCAL_MICROSOFT_ONNX, preferenceManager.source) preferenceManager .observe() .test() .assertNoErrors() - .assertValueAt(0) { settings -> settings.source == ServerSource.LOCAL } + .assertValueAt(0) { settings -> settings.source == ServerSource.LOCAL_MICROSOFT_ONNX } } @Test diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/repository/StabilityAiCreditsRepositoryImplTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/repository/StabilityAiCreditsRepositoryImplTest.kt index 3cc310df..7ef7f17a 100644 --- a/data/src/test/java/com/shifthackz/aisdv1/data/repository/StabilityAiCreditsRepositoryImplTest.kt +++ b/data/src/test/java/com/shifthackz/aisdv1/data/repository/StabilityAiCreditsRepositoryImplTest.kt @@ -40,7 +40,7 @@ class StabilityAiCreditsRepositoryImplTest { fun `given server source is not STABILITY_AI, attempt to fetch, expected IllegalStateException error value`() { every { stubPreferenceManager.source - } returns ServerSource.LOCAL + } returns ServerSource.LOCAL_MICROSOFT_ONNX every { stubRemoteDataSource.fetch() @@ -62,7 +62,7 @@ class StabilityAiCreditsRepositoryImplTest { fun `given server source is not STABILITY_AI, attempt to fetch and get, expected IllegalStateException error value`() { every { stubPreferenceManager.source - } returns ServerSource.LOCAL + } returns ServerSource.LOCAL_MICROSOFT_ONNX every { stubRemoteDataSource.fetch() @@ -88,7 +88,7 @@ class StabilityAiCreditsRepositoryImplTest { fun `given server source is not STABILITY_AI, attempt to fetch and observe, expected IllegalStateException error value`() { every { stubPreferenceManager.source - } returns ServerSource.LOCAL + } returns ServerSource.LOCAL_MICROSOFT_ONNX every { stubRemoteDataSource.fetch() @@ -110,7 +110,7 @@ class StabilityAiCreditsRepositoryImplTest { fun `given server source is not STABILITY_AI, attempt to get, expected IllegalStateException error value`() { every { stubPreferenceManager.source - } returns ServerSource.LOCAL + } returns ServerSource.LOCAL_MICROSOFT_ONNX every { stubLocalDataSource.get() @@ -128,7 +128,7 @@ class StabilityAiCreditsRepositoryImplTest { fun `given server source is not STABILITY_AI, attempt to observe, expected IllegalStateException error value`() { every { stubPreferenceManager.source - } returns ServerSource.LOCAL + } returns ServerSource.LOCAL_MICROSOFT_ONNX repository .observe() diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/ServerSource.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/ServerSource.kt index 3ebf9ac9..bd8af79e 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/ServerSource.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/ServerSource.kt @@ -62,7 +62,7 @@ enum class ServerSource( FeatureTag.Batch, ), ), - LOCAL( + LOCAL_MICROSOFT_ONNX( key = "local", featureTags = setOf( FeatureTag.Offline, diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/InterruptGenerationUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/InterruptGenerationUseCaseImpl.kt index 0f769197..37f6720c 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/InterruptGenerationUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/InterruptGenerationUseCaseImpl.kt @@ -17,7 +17,7 @@ internal class InterruptGenerationUseCaseImpl( override fun invoke() = when (preferenceManager.source) { ServerSource.AUTOMATIC1111 -> stableDiffusionGenerationRepository.interruptGeneration() ServerSource.HORDE -> hordeGenerationRepository.interruptGeneration() - ServerSource.LOCAL -> localDiffusionGenerationRepository.interruptGeneration() + ServerSource.LOCAL_MICROSOFT_ONNX -> localDiffusionGenerationRepository.interruptGeneration() else -> Completable.complete() } } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImpl.kt index 2d38a598..f823a153 100755 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImpl.kt @@ -36,7 +36,7 @@ internal class TextToImageUseCaseImpl( private fun generate(payload: TextToImagePayload) = when (preferenceManager.source) { ServerSource.HORDE -> hordeGenerationRepository.generateFromText(payload) - ServerSource.LOCAL -> localDiffusionGenerationRepository.generateFromText(payload) + ServerSource.LOCAL_MICROSOFT_ONNX -> localDiffusionGenerationRepository.generateFromText(payload) ServerSource.HUGGING_FACE -> huggingFaceGenerationRepository.generateFromText(payload) ServerSource.AUTOMATIC1111 -> stableDiffusionGenerationRepository.generateFromText(payload) ServerSource.OPEN_AI -> openAiGenerationRepository.generateFromText(payload) diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToLocalDiffusionUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToLocalDiffusionUseCaseImpl.kt index 0bbed0c9..5d89e78c 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToLocalDiffusionUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToLocalDiffusionUseCaseImpl.kt @@ -11,7 +11,7 @@ internal class ConnectToLocalDiffusionUseCaseImpl( override fun invoke(modelId: String) = getConfigurationUseCase() .map { originalConfiguration -> originalConfiguration.copy( - source = ServerSource.LOCAL, + source = ServerSource.LOCAL_MICROSOFT_ONNX, localModelId = modelId, ) } diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/ImageToImageUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/ImageToImageUseCaseImplTest.kt index 9f2fac3b..5641d4f2 100644 --- a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/ImageToImageUseCaseImplTest.kt +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/ImageToImageUseCaseImplTest.kt @@ -293,7 +293,7 @@ class ImageToImageUseCaseImplTest { @Test fun `given source is LOCAL, expected Img2Img not yet supported error`() { whenever(stubPreferenceManager.source) - .thenReturn(ServerSource.LOCAL) + .thenReturn(ServerSource.LOCAL_MICROSOFT_ONNX) useCase(mockImageToImagePayload) .test() diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/InterruptGenerationUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/InterruptGenerationUseCaseImplTest.kt index 5ef6334c..895bb8a8 100644 --- a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/InterruptGenerationUseCaseImplTest.kt +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/InterruptGenerationUseCaseImplTest.kt @@ -88,7 +88,7 @@ class InterruptGenerationUseCaseImplTest { @Test fun `given source is LOCAL, api interrupt success, expected complete value`() { whenever(stubPreferenceManager.source) - .thenReturn(ServerSource.LOCAL) + .thenReturn(ServerSource.LOCAL_MICROSOFT_ONNX) whenever(stubLocalDiffusionGenerationRepository.interruptGeneration()) .thenReturn(Completable.complete()) @@ -103,7 +103,7 @@ class InterruptGenerationUseCaseImplTest { @Test fun `given source is LOCAL, api interrupt fail, expected error value`() { whenever(stubPreferenceManager.source) - .thenReturn(ServerSource.LOCAL) + .thenReturn(ServerSource.LOCAL_MICROSOFT_ONNX) whenever(stubLocalDiffusionGenerationRepository.interruptGeneration()) .thenReturn(Completable.error(stubException)) diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImplTest.kt index b6d3bced..b028176d 100644 --- a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImplTest.kt +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImplTest.kt @@ -363,7 +363,7 @@ class TextToImageUseCaseImplTest { @Test fun `given source is LOCAL, batch count is 1, generated successfully, expected generations list with size 1`() { whenever(stubPreferenceManager.source) - .thenReturn(ServerSource.LOCAL) + .thenReturn(ServerSource.LOCAL_MICROSOFT_ONNX) whenever(stubLocalDiffusionGenerationRepository.generateFromText(any())) .thenReturn(Single.just(mockAiGenerationResult)) @@ -386,7 +386,7 @@ class TextToImageUseCaseImplTest { @Test fun `given source is LOCAL, batch count is 10, generated successfully, expected generations list with size 10`() { whenever(stubPreferenceManager.source) - .thenReturn(ServerSource.LOCAL) + .thenReturn(ServerSource.LOCAL_MICROSOFT_ONNX) whenever(stubLocalDiffusionGenerationRepository.generateFromText(any())) .thenReturn(Single.just(mockAiGenerationResult)) @@ -409,7 +409,7 @@ class TextToImageUseCaseImplTest { @Test fun `given source is LOCAL, batch count is 1, generate failed, expected error`() { whenever(stubPreferenceManager.source) - .thenReturn(ServerSource.LOCAL) + .thenReturn(ServerSource.LOCAL_MICROSOFT_ONNX) whenever(stubLocalDiffusionGenerationRepository.generateFromText(any())) .thenReturn(Single.error(stubException)) diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/splash/SplashNavigationUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/splash/SplashNavigationUseCaseImplTest.kt index 1c2b3a0d..092b4c19 100644 --- a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/splash/SplashNavigationUseCaseImplTest.kt +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/splash/SplashNavigationUseCaseImplTest.kt @@ -89,7 +89,7 @@ class SplashNavigationUseCaseImplTest { .thenReturn("") whenever(stubPreferenceManager.source) - .thenReturn(ServerSource.LOCAL) + .thenReturn(ServerSource.LOCAL_MICROSOFT_ONNX) useCase() .test() @@ -109,7 +109,7 @@ class SplashNavigationUseCaseImplTest { .thenReturn("http://192.168.0.1:7860") whenever(stubPreferenceManager.source) - .thenReturn(ServerSource.LOCAL) + .thenReturn(ServerSource.LOCAL_MICROSOFT_ONNX) useCase() .test() diff --git a/feature/mediapipe/build.gradle.kts b/feature/mediapipe/build.gradle.kts index 9612f9b1..0249feed 100644 --- a/feature/mediapipe/build.gradle.kts +++ b/feature/mediapipe/build.gradle.kts @@ -1,9 +1,10 @@ plugins { alias(libs.plugins.generic.library) + alias(libs.plugins.generic.flavors) } android { - namespace = "com.shifthackz.aisdv1.feature.diffusion" + namespace = "com.shifthackz.aisdv1.feature.mediapipe" } dependencies { @@ -11,15 +12,6 @@ dependencies { implementation(project(":domain")) implementation(libs.koin.core) implementation(libs.rx.kotlin) - implementation(libs.google.mediapipe.image.generator) { - exclude(group = "com.google.firebase", module = "firebase-encoders") - exclude(group = "com.google.firebase", module = "firebase-encoders-json") - exclude(group = "com.google.firebase", module = "firebase-encoders-proto") - exclude(group = "com.google.flogger", module = "flogger") - exclude(group = "com.google.flogger", module = "flogger-system-backend") - exclude(group = "com.google.android.datatransport", module = "transport-api") - exclude(group = "com.google.android.datatransport", module = "transport-backend-cct") - exclude(group = "com.google.android.datatransport", module = "transport-runtime") - } - implementation("org.jetbrains.kotlinx:kotlinx-coroutines-core:1.8.1") + fullImplementation(libs.google.mediapipe.image.generator) + playstoreImplementation(libs.google.mediapipe.image.generator) } diff --git a/feature/mediapipe/src/foss/java/com/shifthackz/aisdv1/feature/mediapipe/MediaPipeImpl.kt b/feature/mediapipe/src/foss/java/com/shifthackz/aisdv1/feature/mediapipe/MediaPipeImpl.kt new file mode 100644 index 00000000..0defdf78 --- /dev/null +++ b/feature/mediapipe/src/foss/java/com/shifthackz/aisdv1/feature/mediapipe/MediaPipeImpl.kt @@ -0,0 +1,13 @@ +package com.shifthackz.aisdv1.feature.mediapipe + +import android.content.Context +import android.graphics.Bitmap +import com.shifthackz.aisdv1.core.common.file.FileProviderDescriptor +import com.shifthackz.aisdv1.domain.entity.TextToImagePayload +import com.shifthackz.aisdv1.domain.feature.mediapipe.MediaPipe +import io.reactivex.rxjava3.core.Single + +internal class MediaPipeImpl : MediaPipe { + + override fun process(payload: TextToImagePayload): Single = Single.error(Throwable("null")) +} diff --git a/feature/mediapipe/src/main/java/com/shifthackz/aisdv1/feature/mediapipe/MediaPipeImpl.kt b/feature/mediapipe/src/full/java/com/shifthackz/aisdv1/feature/mediapipe/MediaPipeImpl.kt similarity index 67% rename from feature/mediapipe/src/main/java/com/shifthackz/aisdv1/feature/mediapipe/MediaPipeImpl.kt rename to feature/mediapipe/src/full/java/com/shifthackz/aisdv1/feature/mediapipe/MediaPipeImpl.kt index 32e898fb..c75caad8 100644 --- a/feature/mediapipe/src/main/java/com/shifthackz/aisdv1/feature/mediapipe/MediaPipeImpl.kt +++ b/feature/mediapipe/src/full/java/com/shifthackz/aisdv1/feature/mediapipe/MediaPipeImpl.kt @@ -18,23 +18,23 @@ internal class MediaPipeImpl( private var imageGenerator: ImageGenerator? = null override fun process(payload: TextToImagePayload): Single = Single.create { emitter -> - try { - initialize() - println("Generating...") - val result = imageGenerator?.generate( - payload.prompt, - payload.samplingSteps, - payload.seed.toIntOrNull() ?: 0, - ) - println("Extracting bitmap...") - val bitmap = BitmapExtractor.extract(result?.generatedImage()) - println("bitmap = $bitmap, ${bitmap.width}X${bitmap.height}") - close() - if (!emitter.isDisposed) emitter.onSuccess(bitmap) - } catch (e: Exception) { - close() - if (!emitter.isDisposed) emitter.onError(e) - } + try { + initialize() + println("Generating...") + val result = imageGenerator?.generate( + payload.prompt, + payload.samplingSteps, + payload.seed.toIntOrNull() ?: 0, + ) + println("Extracting bitmap...") + val bitmap = BitmapExtractor.extract(result?.generatedImage()) + println("bitmap = $bitmap, ${bitmap.width}X${bitmap.height}") + close() + if (!emitter.isDisposed) emitter.onSuccess(bitmap) + } catch (e: Exception) { + close() + if (!emitter.isDisposed) emitter.onError(e) + } } private fun initialize(): ImageGenerator { diff --git a/feature/mediapipe/src/main/java/com/google/android/datatransport/AutoValue_Event.java b/feature/mediapipe/src/main/java/com/google/android/datatransport/AutoValue_Event.java deleted file mode 100644 index 58ecd456..00000000 --- a/feature/mediapipe/src/main/java/com/google/android/datatransport/AutoValue_Event.java +++ /dev/null @@ -1,27 +0,0 @@ -package com.google.android.datatransport; - -import androidx.annotation.Nullable; - -public class AutoValue_Event extends Event { - - AutoValue_Event(@Nullable Integer code, T payload, Priority priority) { - - - } - - @Nullable - @Override - public Integer getCode() { - return 0; - } - - @Override - public T getPayload() { - return null; - } - - @Override - public Priority getPriority() { - return null; - } -} diff --git a/feature/mediapipe/src/main/java/com/google/android/datatransport/Encoding.java b/feature/mediapipe/src/main/java/com/google/android/datatransport/Encoding.java deleted file mode 100644 index 4cbbb057..00000000 --- a/feature/mediapipe/src/main/java/com/google/android/datatransport/Encoding.java +++ /dev/null @@ -1,18 +0,0 @@ -package com.google.android.datatransport; - -import androidx.annotation.NonNull; - -public final class Encoding { - - public static Encoding of(@NonNull String name) { - return new Encoding(name); - } - - public String getName() { - return ""; - } - - private Encoding(@NonNull String name) { - - } -} diff --git a/feature/mediapipe/src/main/java/com/google/android/datatransport/Event.java b/feature/mediapipe/src/main/java/com/google/android/datatransport/Event.java deleted file mode 100644 index 5cbd01c0..00000000 --- a/feature/mediapipe/src/main/java/com/google/android/datatransport/Event.java +++ /dev/null @@ -1,48 +0,0 @@ -package com.google.android.datatransport; - - -import androidx.annotation.Nullable; - -public abstract class Event { - public Event() { - } - - @Nullable - public abstract Integer getCode(); - - @Nullable - public abstract T getPayload(); - - @Nullable - public abstract Priority getPriority(); - - @Nullable - public static Event ofData(int code, T payload) { - return null; - } - - @Nullable - public static Event ofData(T payload) { - return null; - } - - @Nullable - public static Event ofTelemetry(int code, T value) { - return null; - } - - @Nullable - public static Event ofTelemetry(T value) { - return null; - } - - @Nullable - public static Event ofUrgent(int code, T value) { - return null; - } - - @Nullable - public static Event ofUrgent(T value) { - return null; - } -} diff --git a/feature/mediapipe/src/main/java/com/google/android/datatransport/Priority.java b/feature/mediapipe/src/main/java/com/google/android/datatransport/Priority.java deleted file mode 100644 index 6bef1467..00000000 --- a/feature/mediapipe/src/main/java/com/google/android/datatransport/Priority.java +++ /dev/null @@ -1,10 +0,0 @@ -package com.google.android.datatransport; - -public enum Priority { - DEFAULT, - VERY_LOW, - HIGHEST; - - private Priority() { - } -} diff --git a/feature/mediapipe/src/main/java/com/google/android/datatransport/Transformer.java b/feature/mediapipe/src/main/java/com/google/android/datatransport/Transformer.java deleted file mode 100644 index 426d2986..00000000 --- a/feature/mediapipe/src/main/java/com/google/android/datatransport/Transformer.java +++ /dev/null @@ -1,5 +0,0 @@ -package com.google.android.datatransport; - -public interface Transformer { - U apply(T var1); -} diff --git a/feature/mediapipe/src/main/java/com/google/android/datatransport/Transport.java b/feature/mediapipe/src/main/java/com/google/android/datatransport/Transport.java deleted file mode 100644 index 5bdf02d2..00000000 --- a/feature/mediapipe/src/main/java/com/google/android/datatransport/Transport.java +++ /dev/null @@ -1,7 +0,0 @@ -package com.google.android.datatransport; - -public interface Transport { - void send(Event var1); - - void schedule(Event var1, TransportScheduleCallback var2); -} diff --git a/feature/mediapipe/src/main/java/com/google/android/datatransport/TransportFactory.java b/feature/mediapipe/src/main/java/com/google/android/datatransport/TransportFactory.java deleted file mode 100644 index a092b671..00000000 --- a/feature/mediapipe/src/main/java/com/google/android/datatransport/TransportFactory.java +++ /dev/null @@ -1,8 +0,0 @@ -package com.google.android.datatransport; - -public interface TransportFactory { - @Deprecated - Transport getTransport(String var1, Class var2, Transformer var3); - - Transport getTransport(String var1, Class var2, Encoding var3, Transformer var4); -} diff --git a/feature/mediapipe/src/main/java/com/google/android/datatransport/TransportScheduleCallback.java b/feature/mediapipe/src/main/java/com/google/android/datatransport/TransportScheduleCallback.java deleted file mode 100644 index 8a06f9aa..00000000 --- a/feature/mediapipe/src/main/java/com/google/android/datatransport/TransportScheduleCallback.java +++ /dev/null @@ -1,7 +0,0 @@ -package com.google.android.datatransport; - -import androidx.annotation.Nullable; - -public interface TransportScheduleCallback { - void onSchedule(@Nullable Exception var1); -} diff --git a/feature/mediapipe/src/main/java/com/google/android/datatransport/cct/CCTDestination.java b/feature/mediapipe/src/main/java/com/google/android/datatransport/cct/CCTDestination.java deleted file mode 100644 index 6e5fd45a..00000000 --- a/feature/mediapipe/src/main/java/com/google/android/datatransport/cct/CCTDestination.java +++ /dev/null @@ -1,6 +0,0 @@ -package com.google.android.datatransport.cct; - -public class CCTDestination { - - public static CCTDestination INSTANCE = new CCTDestination(); -} diff --git a/feature/mediapipe/src/main/java/com/google/android/datatransport/runtime/TransportRuntime.java b/feature/mediapipe/src/main/java/com/google/android/datatransport/runtime/TransportRuntime.java deleted file mode 100644 index 2c15ecf9..00000000 --- a/feature/mediapipe/src/main/java/com/google/android/datatransport/runtime/TransportRuntime.java +++ /dev/null @@ -1,12 +0,0 @@ -package com.google.android.datatransport.runtime; - -import android.content.Context; - -public class TransportRuntime { - - public static void initialize(Context applicationContext) {} - - public static TransportRuntime getInstance() { - return new TransportRuntime(); - } -} diff --git a/feature/mediapipe/src/main/java/com/google/android/datatransport/runtime/backends/TransportBackend.java b/feature/mediapipe/src/main/java/com/google/android/datatransport/runtime/backends/TransportBackend.java deleted file mode 100644 index e2e45b07..00000000 --- a/feature/mediapipe/src/main/java/com/google/android/datatransport/runtime/backends/TransportBackend.java +++ /dev/null @@ -1,9 +0,0 @@ -//package com.google.android.datatransport.runtime.backends; -// -//import com.google.android.datatransport.runtime.EventInternal; -// -//public interface TransportBackend { -// EventInternal decorate(EventInternal var1); -// -// BackendResponse send(BackendRequest var1); -//} diff --git a/feature/mediapipe/src/main/java/com/google/common/flogger/FluentLogger.java b/feature/mediapipe/src/main/java/com/google/common/flogger/FluentLogger.java deleted file mode 100644 index b1b88f39..00000000 --- a/feature/mediapipe/src/main/java/com/google/common/flogger/FluentLogger.java +++ /dev/null @@ -1,9 +0,0 @@ -package com.google.common.flogger; - -public class FluentLogger { - - public static FluentLogger forEnclosingClass() { - return new FluentLogger(); - } -} -//com/google/firebase/encoders/json/JsonDataEncoderBuilder; \ No newline at end of file diff --git a/feature/mediapipe/src/main/java/com/google/firebase/encoders/DataEncoder.java b/feature/mediapipe/src/main/java/com/google/firebase/encoders/DataEncoder.java deleted file mode 100644 index 2cab88ee..00000000 --- a/feature/mediapipe/src/main/java/com/google/firebase/encoders/DataEncoder.java +++ /dev/null @@ -1,16 +0,0 @@ -package com.google.firebase.encoders; - -import androidx.annotation.NonNull; - -import java.io.IOException; -import java.io.Writer; - -public interface DataEncoder { - - /** Encodes {@code obj} into {@code writer}. */ - void encode(@NonNull Object obj, @NonNull Writer writer) throws IOException; - - /** Returns the string-encoded representation of {@code obj}. */ - @NonNull - String encode(@NonNull Object obj); -} diff --git a/feature/mediapipe/src/main/java/com/google/firebase/encoders/Encoder.java b/feature/mediapipe/src/main/java/com/google/firebase/encoders/Encoder.java deleted file mode 100644 index 25568ade..00000000 --- a/feature/mediapipe/src/main/java/com/google/firebase/encoders/Encoder.java +++ /dev/null @@ -1,11 +0,0 @@ -package com.google.firebase.encoders; - -import androidx.annotation.NonNull; - -import java.io.IOException; - -interface Encoder { - - /** Encode {@code obj} using {@code TContext}. */ - void encode(@NonNull TValue obj, @NonNull TContext context) throws IOException; -} diff --git a/feature/mediapipe/src/main/java/com/google/firebase/encoders/EncodingException.java b/feature/mediapipe/src/main/java/com/google/firebase/encoders/EncodingException.java deleted file mode 100644 index 9ad3c896..00000000 --- a/feature/mediapipe/src/main/java/com/google/firebase/encoders/EncodingException.java +++ /dev/null @@ -1,14 +0,0 @@ -package com.google.firebase.encoders; - -import androidx.annotation.NonNull; - -public final class EncodingException extends RuntimeException { - - public EncodingException(@NonNull String message) { - super(message); - } - - public EncodingException(@NonNull String message, @NonNull Exception cause) { - super(message, cause); - } -} diff --git a/feature/mediapipe/src/main/java/com/google/firebase/encoders/FieldDescriptor.java b/feature/mediapipe/src/main/java/com/google/firebase/encoders/FieldDescriptor.java deleted file mode 100644 index 1bad9234..00000000 --- a/feature/mediapipe/src/main/java/com/google/firebase/encoders/FieldDescriptor.java +++ /dev/null @@ -1,59 +0,0 @@ -package com.google.firebase.encoders; - -import androidx.annotation.NonNull; -import androidx.annotation.Nullable; - -import java.lang.annotation.Annotation; -import java.util.Collections; -import java.util.Map; - -public final class FieldDescriptor { - - private FieldDescriptor(String name, Map, Object> properties) { - } - - /** Name of the field. */ - @NonNull - public String getName() { - return ""; - } - - /** - * Provides access to extra properties of the field. - * - * @return {@code T} annotation if present, null otherwise. - */ - @Nullable - @SuppressWarnings("unchecked") - public T getProperty(@NonNull Class type) { - return null; - } - - @NonNull - public static FieldDescriptor of(@NonNull String name) { - return new FieldDescriptor(name, Collections.emptyMap()); - } - - @NonNull - public static Builder builder(@NonNull String name) { - return new Builder(name); - } - - public static final class Builder { - - - Builder(String name) { - - } - - @NonNull - public Builder withProperty(@NonNull T value) { - return this; - } - - @NonNull - public FieldDescriptor build() { - return new FieldDescriptor("", Collections.emptyMap()); - } - } -} diff --git a/feature/mediapipe/src/main/java/com/google/firebase/encoders/ObjectEncoder.java b/feature/mediapipe/src/main/java/com/google/firebase/encoders/ObjectEncoder.java deleted file mode 100644 index 12271efc..00000000 --- a/feature/mediapipe/src/main/java/com/google/firebase/encoders/ObjectEncoder.java +++ /dev/null @@ -1,3 +0,0 @@ -package com.google.firebase.encoders; - -public interface ObjectEncoder extends Encoder {} diff --git a/feature/mediapipe/src/main/java/com/google/firebase/encoders/ObjectEncoderContext.java b/feature/mediapipe/src/main/java/com/google/firebase/encoders/ObjectEncoderContext.java deleted file mode 100644 index 77936279..00000000 --- a/feature/mediapipe/src/main/java/com/google/firebase/encoders/ObjectEncoderContext.java +++ /dev/null @@ -1,56 +0,0 @@ -package com.google.firebase.encoders; - -import androidx.annotation.NonNull; -import androidx.annotation.Nullable; - -import java.io.IOException; - -public interface ObjectEncoderContext { - - @Deprecated - @NonNull - ObjectEncoderContext add(@NonNull String name, @Nullable Object obj) throws IOException; - - @Deprecated - @NonNull - ObjectEncoderContext add(@NonNull String name, double value) throws IOException; - - @Deprecated - @NonNull - ObjectEncoderContext add(@NonNull String name, int value) throws IOException; - - @Deprecated - @NonNull - ObjectEncoderContext add(@NonNull String name, long value) throws IOException; - - @Deprecated - @NonNull - ObjectEncoderContext add(@NonNull String name, boolean value) throws IOException; - - @NonNull - ObjectEncoderContext add(@NonNull FieldDescriptor field, @Nullable Object obj) throws IOException; - - @NonNull - ObjectEncoderContext add(@NonNull FieldDescriptor field, float value) throws IOException; - - @NonNull - ObjectEncoderContext add(@NonNull FieldDescriptor field, double value) throws IOException; - - @NonNull - ObjectEncoderContext add(@NonNull FieldDescriptor field, int value) throws IOException; - - @NonNull - ObjectEncoderContext add(@NonNull FieldDescriptor field, long value) throws IOException; - - @NonNull - ObjectEncoderContext add(@NonNull FieldDescriptor field, boolean value) throws IOException; - - @NonNull - ObjectEncoderContext inline(@Nullable Object value) throws IOException; - - @NonNull - ObjectEncoderContext nested(@NonNull String name) throws IOException; - - @NonNull - ObjectEncoderContext nested(@NonNull FieldDescriptor field) throws IOException; -} diff --git a/feature/mediapipe/src/main/java/com/google/firebase/encoders/ValueEncoder.java b/feature/mediapipe/src/main/java/com/google/firebase/encoders/ValueEncoder.java deleted file mode 100644 index 48ae1168..00000000 --- a/feature/mediapipe/src/main/java/com/google/firebase/encoders/ValueEncoder.java +++ /dev/null @@ -1,3 +0,0 @@ -package com.google.firebase.encoders; - -public interface ValueEncoder extends Encoder {} diff --git a/feature/mediapipe/src/main/java/com/google/firebase/encoders/ValueEncoderContext.java b/feature/mediapipe/src/main/java/com/google/firebase/encoders/ValueEncoderContext.java deleted file mode 100644 index 85a72c4d..00000000 --- a/feature/mediapipe/src/main/java/com/google/firebase/encoders/ValueEncoderContext.java +++ /dev/null @@ -1,37 +0,0 @@ -package com.google.firebase.encoders; - -import androidx.annotation.NonNull; -import androidx.annotation.Nullable; - -import java.io.IOException; - -public interface ValueEncoderContext { - - /** Adds {@code value} as a primitive encoded value. */ - @NonNull - ValueEncoderContext add(@Nullable String value) throws IOException; - - /** Adds {@code value} as a primitive encoded value. */ - @NonNull - ValueEncoderContext add(float value) throws IOException; - - /** Adds {@code value} as a primitive encoded value. */ - @NonNull - ValueEncoderContext add(double value) throws IOException; - - /** Adds {@code value} as a primitive encoded value. */ - @NonNull - ValueEncoderContext add(int value) throws IOException; - - /** Adds {@code value} as a primitive encoded value. */ - @NonNull - ValueEncoderContext add(long value) throws IOException; - - /** Adds {@code value} as a primitive encoded value. */ - @NonNull - ValueEncoderContext add(boolean value) throws IOException; - - /** Adds {@code value} as a encoded array of bytes. */ - @NonNull - ValueEncoderContext add(@NonNull byte[] bytes) throws IOException; -} diff --git a/feature/mediapipe/src/main/java/com/google/firebase/encoders/config/Configurator.java b/feature/mediapipe/src/main/java/com/google/firebase/encoders/config/Configurator.java deleted file mode 100644 index 6332c0fe..00000000 --- a/feature/mediapipe/src/main/java/com/google/firebase/encoders/config/Configurator.java +++ /dev/null @@ -1,7 +0,0 @@ -package com.google.firebase.encoders.config; - -import androidx.annotation.NonNull; - -public interface Configurator { - void configure(@NonNull EncoderConfig configuration); -} diff --git a/feature/mediapipe/src/main/java/com/google/firebase/encoders/config/EncoderConfig.java b/feature/mediapipe/src/main/java/com/google/firebase/encoders/config/EncoderConfig.java deleted file mode 100644 index 308e28ba..00000000 --- a/feature/mediapipe/src/main/java/com/google/firebase/encoders/config/EncoderConfig.java +++ /dev/null @@ -1,14 +0,0 @@ -package com.google.firebase.encoders.config; - -import androidx.annotation.NonNull; - -import com.google.firebase.encoders.ObjectEncoder; -import com.google.firebase.encoders.ValueEncoder; - -public interface EncoderConfig> { - @NonNull - T registerEncoder(@NonNull Class type, @NonNull ObjectEncoder encoder); - - @NonNull - T registerEncoder(@NonNull Class type, @NonNull ValueEncoder encoder); -} diff --git a/feature/mediapipe/src/main/java/com/google/firebase/encoders/json/JsonDataEncoderBuilder.java b/feature/mediapipe/src/main/java/com/google/firebase/encoders/json/JsonDataEncoderBuilder.java deleted file mode 100644 index 2f6ebf5a..00000000 --- a/feature/mediapipe/src/main/java/com/google/firebase/encoders/json/JsonDataEncoderBuilder.java +++ /dev/null @@ -1,66 +0,0 @@ -package com.google.firebase.encoders.json; - -import androidx.annotation.NonNull; - -import com.google.firebase.encoders.DataEncoder; -import com.google.firebase.encoders.ObjectEncoder; -import com.google.firebase.encoders.ValueEncoder; -import com.google.firebase.encoders.config.Configurator; -import com.google.firebase.encoders.config.EncoderConfig; - -import java.io.IOException; -import java.io.Writer; - -public class JsonDataEncoderBuilder implements EncoderConfig { - - public JsonDataEncoderBuilder() { - } - - @NonNull - @Override - public JsonDataEncoderBuilder registerEncoder( - @NonNull Class clazz, @NonNull ObjectEncoder objectEncoder) { - - return this; - } - - @NonNull - @Override - public JsonDataEncoderBuilder registerEncoder( - @NonNull Class clazz, @NonNull ValueEncoder encoder) { - - return this; - } - - /** Encoder used if no encoders are found among explicitly registered ones. */ - @NonNull - public JsonDataEncoderBuilder registerFallbackEncoder( - @NonNull ObjectEncoder fallbackEncoder) { - return this; - } - - @NonNull - public JsonDataEncoderBuilder configureWith(@NonNull Configurator config) { - return this; - } - - @NonNull - public JsonDataEncoderBuilder ignoreNullValues(boolean ignore) { - return this; - } - - @NonNull - public DataEncoder build() { - return new DataEncoder() { - @Override - public void encode(@NonNull Object o, @NonNull Writer writer) throws IOException { - - } - - @Override - public String encode(@NonNull Object o) { - return ""; - } - }; - } -} diff --git a/feature/mediapipe/src/main/java/com/google/firebase/encoders/json/JsonValueObjectEncoderContext.java b/feature/mediapipe/src/main/java/com/google/firebase/encoders/json/JsonValueObjectEncoderContext.java deleted file mode 100644 index 1b697bf8..00000000 --- a/feature/mediapipe/src/main/java/com/google/firebase/encoders/json/JsonValueObjectEncoderContext.java +++ /dev/null @@ -1,201 +0,0 @@ -package com.google.firebase.encoders.json; - -import androidx.annotation.NonNull; -import androidx.annotation.Nullable; - -import com.google.firebase.encoders.EncodingException; -import com.google.firebase.encoders.FieldDescriptor; -import com.google.firebase.encoders.ObjectEncoder; -import com.google.firebase.encoders.ObjectEncoderContext; -import com.google.firebase.encoders.ValueEncoder; -import com.google.firebase.encoders.ValueEncoderContext; - -import java.io.IOException; -import java.io.Writer; -import java.util.Map; - -public class JsonValueObjectEncoderContext implements ObjectEncoderContext, ValueEncoderContext { - - JsonValueObjectEncoderContext( - @NonNull Writer writer, - @NonNull Map, ObjectEncoder> objectEncoders, - @NonNull Map, ValueEncoder> valueEncoders, - ObjectEncoder fallbackEncoder, - boolean ignoreNullValues) { - - } - - JsonValueObjectEncoderContext() {} - - - @NonNull - @Override - public JsonValueObjectEncoderContext add(@NonNull String name, @Nullable Object o) - throws IOException { - return new JsonValueObjectEncoderContext(); - } - - @NonNull - @Override - public JsonValueObjectEncoderContext add(@NonNull String name, double value) throws IOException { - return new JsonValueObjectEncoderContext(); - } - - @NonNull - @Override - public JsonValueObjectEncoderContext add(@NonNull String name, int value) throws IOException { - return new JsonValueObjectEncoderContext(); - } - - @NonNull - @Override - public JsonValueObjectEncoderContext add(@NonNull String name, long value) throws IOException { - return new JsonValueObjectEncoderContext(); - } - - @NonNull - @Override - public JsonValueObjectEncoderContext add(@NonNull String name, boolean value) throws IOException { - return new JsonValueObjectEncoderContext(); - } - - @NonNull - @Override - public ObjectEncoderContext add(@NonNull FieldDescriptor field, @Nullable Object obj) - throws IOException { - return new JsonValueObjectEncoderContext(); - } - - @NonNull - @Override - public ObjectEncoderContext add(@NonNull FieldDescriptor field, float value) throws IOException { - return new JsonValueObjectEncoderContext(); - } - - @NonNull - @Override - public ObjectEncoderContext add(@NonNull FieldDescriptor field, double value) throws IOException { - return new JsonValueObjectEncoderContext(); - } - - @NonNull - @Override - public ObjectEncoderContext add(@NonNull FieldDescriptor field, int value) throws IOException { - return new JsonValueObjectEncoderContext(); - } - - @NonNull - @Override - public ObjectEncoderContext add(@NonNull FieldDescriptor field, long value) throws IOException { - return new JsonValueObjectEncoderContext(); - } - - @NonNull - @Override - public ObjectEncoderContext add(@NonNull FieldDescriptor field, boolean value) - throws IOException { - return new JsonValueObjectEncoderContext(); - } - - @NonNull - @Override - public ObjectEncoderContext inline(@Nullable Object value) throws IOException { - return new JsonValueObjectEncoderContext(); - } - - @NonNull - @Override - public ObjectEncoderContext nested(@NonNull String name) throws IOException { - - return new JsonValueObjectEncoderContext(); - } - - @NonNull - @Override - public ObjectEncoderContext nested(@NonNull FieldDescriptor field) throws IOException { - return nested(field.getName()); - } - - @NonNull - @Override - public JsonValueObjectEncoderContext add(@Nullable String value) throws IOException { - - return this; - } - - @NonNull - @Override - public JsonValueObjectEncoderContext add(float value) throws IOException { - - - return this; - } - - @NonNull - @Override - public JsonValueObjectEncoderContext add(double value) throws IOException { - - return this; - } - - @NonNull - @Override - public JsonValueObjectEncoderContext add(int value) throws IOException { - - return this; - } - - @NonNull - @Override - public JsonValueObjectEncoderContext add(long value) throws IOException { - - return this; - } - - @NonNull - @Override - public JsonValueObjectEncoderContext add(boolean value) throws IOException { - - return this; - } - - @NonNull - @Override - public JsonValueObjectEncoderContext add(@Nullable byte[] bytes) throws IOException { - - return this; - } - - @NonNull - JsonValueObjectEncoderContext add(@Nullable Object o, boolean inline) throws IOException { - - return new JsonValueObjectEncoderContext(); - } - - JsonValueObjectEncoderContext doEncode(ObjectEncoder encoder, Object o, boolean inline) - throws IOException { - return this; - } - - private boolean cannotBeInline(Object value) { - return true; - } - - void close() throws IOException { - - } - - private void maybeUnNest() throws IOException { - - } - - private JsonValueObjectEncoderContext internalAdd(@NonNull String name, @Nullable Object o) - throws IOException, EncodingException { - return add(o, false); - } - - private JsonValueObjectEncoderContext internalAddIgnoreNullValues( - @NonNull String name, @Nullable Object o) throws IOException, EncodingException { - return add(o, false); - } -} diff --git a/feature/mediapipe/src/main/java/com/google/firebase/encoders/proto/AtProtobuf.java b/feature/mediapipe/src/main/java/com/google/firebase/encoders/proto/AtProtobuf.java deleted file mode 100644 index cdf7a87f..00000000 --- a/feature/mediapipe/src/main/java/com/google/firebase/encoders/proto/AtProtobuf.java +++ /dev/null @@ -1,48 +0,0 @@ -package com.google.firebase.encoders.proto; - -import java.lang.annotation.Annotation; - -public class AtProtobuf { - public AtProtobuf() { - - } - - public AtProtobuf tag(int tag) { - - return this; - } - - public AtProtobuf intEncoding(Protobuf.IntEncoding intEncoding) { - - return this; - } - - public static AtProtobuf builder() { - return new AtProtobuf(); - } - - public Protobuf build() { - return new ProtobufImpl(); - } - - private static final class ProtobufImpl implements Protobuf { - - ProtobufImpl(int tag, Protobuf.IntEncoding intEncoding) { - - } - - ProtobufImpl() {} - - public Class annotationType() { - return Protobuf.class; - } - - public int tag() { - return 0; - } - - public Protobuf.IntEncoding intEncoding() { - return IntEncoding.DEFAULT; - } - } -} diff --git a/feature/mediapipe/src/main/java/com/google/firebase/encoders/proto/LengthCountingOutputStream.java b/feature/mediapipe/src/main/java/com/google/firebase/encoders/proto/LengthCountingOutputStream.java deleted file mode 100644 index 0d13838f..00000000 --- a/feature/mediapipe/src/main/java/com/google/firebase/encoders/proto/LengthCountingOutputStream.java +++ /dev/null @@ -1,27 +0,0 @@ -package com.google.firebase.encoders.proto; - -import androidx.annotation.NonNull; - -import java.io.OutputStream; - -public class LengthCountingOutputStream extends OutputStream { - - @Override - public void write(int b) { - - } - - @Override - public void write(byte[] b) { - - } - - @Override - public void write(@NonNull byte[] b, int off, int len) { - - } - - long getLength() { - return 0L; - } -} diff --git a/feature/mediapipe/src/main/java/com/google/firebase/encoders/proto/ProtoEnum.java b/feature/mediapipe/src/main/java/com/google/firebase/encoders/proto/ProtoEnum.java deleted file mode 100644 index f3499ccf..00000000 --- a/feature/mediapipe/src/main/java/com/google/firebase/encoders/proto/ProtoEnum.java +++ /dev/null @@ -1,7 +0,0 @@ -package com.google.firebase.encoders.proto; - -public interface ProtoEnum { - - /** Numeric representation of the Enum. */ - int getNumber(); -} diff --git a/feature/mediapipe/src/main/java/com/google/firebase/encoders/proto/Protobuf.java b/feature/mediapipe/src/main/java/com/google/firebase/encoders/proto/Protobuf.java deleted file mode 100644 index f7bc6234..00000000 --- a/feature/mediapipe/src/main/java/com/google/firebase/encoders/proto/Protobuf.java +++ /dev/null @@ -1,14 +0,0 @@ -package com.google.firebase.encoders.proto; - -public @interface Protobuf { - int tag(); - - /** Specifies numeric field encoding. */ - IntEncoding intEncoding() default IntEncoding.DEFAULT; - - enum IntEncoding { - DEFAULT, - SIGNED, - FIXED - } -} diff --git a/feature/mediapipe/src/main/java/com/google/firebase/encoders/proto/ProtobufDataEncoderContext.java b/feature/mediapipe/src/main/java/com/google/firebase/encoders/proto/ProtobufDataEncoderContext.java deleted file mode 100644 index b9a063b2..00000000 --- a/feature/mediapipe/src/main/java/com/google/firebase/encoders/proto/ProtobufDataEncoderContext.java +++ /dev/null @@ -1,151 +0,0 @@ -package com.google.firebase.encoders.proto; - -import androidx.annotation.NonNull; -import androidx.annotation.Nullable; - -import com.google.firebase.encoders.EncodingException; -import com.google.firebase.encoders.FieldDescriptor; -import com.google.firebase.encoders.ObjectEncoder; -import com.google.firebase.encoders.ObjectEncoderContext; -import com.google.firebase.encoders.ValueEncoder; - -import java.io.IOException; -import java.io.OutputStream; -import java.util.Map; - -public class ProtobufDataEncoderContext implements ObjectEncoderContext { - - ProtobufDataEncoderContext( - OutputStream output, - Map, ObjectEncoder> objectEncoders, - Map, ValueEncoder> valueEncoders, - ObjectEncoder fallbackEncoder) { - - } - - @NonNull - @Override - public ObjectEncoderContext add(@NonNull String name, @Nullable Object obj) throws IOException { - return add(FieldDescriptor.of(name), obj); - } - - @NonNull - @Override - public ObjectEncoderContext add(@NonNull String name, double value) throws IOException { - return add(FieldDescriptor.of(name), value); - } - - @NonNull - @Override - public ObjectEncoderContext add(@NonNull String name, int value) throws IOException { - return add(FieldDescriptor.of(name), value); - } - - @NonNull - @Override - public ObjectEncoderContext add(@NonNull String name, long value) throws IOException { - return add(FieldDescriptor.of(name), value); - } - - @NonNull - @Override - public ObjectEncoderContext add(@NonNull String name, boolean value) throws IOException { - return add(FieldDescriptor.of(name), value); - } - - @NonNull - @Override - public ObjectEncoderContext add(@NonNull FieldDescriptor field, @Nullable Object obj) - throws IOException { - return add(field, obj, true); - } - - ObjectEncoderContext add( - @NonNull FieldDescriptor field, @Nullable Object obj, boolean skipDefault) - throws IOException { - - return this; - } - - @NonNull - @Override - public ObjectEncoderContext add(@NonNull FieldDescriptor field, double value) throws IOException { - return add(field, value, true); - } - - ObjectEncoderContext add(@NonNull FieldDescriptor field, double value, boolean skipDefault) - throws IOException { - return this; - } - - @NonNull - @Override - public ObjectEncoderContext add(@NonNull FieldDescriptor field, float value) throws IOException { - - return add(field, value, true); - } - - ObjectEncoderContext add(@NonNull FieldDescriptor field, float value, boolean skipDefault) - throws IOException { - return this; - } - - @NonNull - @Override - public ProtobufDataEncoderContext add(@NonNull FieldDescriptor field, int value) - throws IOException { - return add(field, value, true); - } - - ProtobufDataEncoderContext add(@NonNull FieldDescriptor field, int value, boolean skipDefault) - throws IOException { - return this; - } - - @NonNull - @Override - public ProtobufDataEncoderContext add(@NonNull FieldDescriptor field, long value) - throws IOException { - return add(field, value, true); - } - - ProtobufDataEncoderContext add(@NonNull FieldDescriptor field, long value, boolean skipDefault) - throws IOException { - return this; - } - - @NonNull - @Override - public ProtobufDataEncoderContext add(@NonNull FieldDescriptor field, boolean value) - throws IOException { - return add(field, value, true); - } - - ProtobufDataEncoderContext add(@NonNull FieldDescriptor field, boolean value, boolean skipDefault) - throws IOException { - return add(field, value ? 1 : 0, skipDefault); - } - - @NonNull - @Override - public ObjectEncoderContext inline(@Nullable Object value) throws IOException { - return encode(value); - } - - ProtobufDataEncoderContext encode(@Nullable Object value) throws IOException { - throw new EncodingException("No encoder for " + value.getClass()); - } - - @NonNull - @Override - public ObjectEncoderContext nested(@NonNull String name) throws IOException { - return nested(FieldDescriptor.of(name)); - } - - @NonNull - @Override - public ObjectEncoderContext nested(@NonNull FieldDescriptor field) throws IOException { - throw new EncodingException("nested() is not implemented for protobuf encoding."); - } - -} diff --git a/feature/mediapipe/src/main/java/com/google/firebase/encoders/proto/ProtobufEncoder.java b/feature/mediapipe/src/main/java/com/google/firebase/encoders/proto/ProtobufEncoder.java deleted file mode 100644 index 1378abca..00000000 --- a/feature/mediapipe/src/main/java/com/google/firebase/encoders/proto/ProtobufEncoder.java +++ /dev/null @@ -1,69 +0,0 @@ -package com.google.firebase.encoders.proto; - -import androidx.annotation.NonNull; - -import com.google.firebase.encoders.ObjectEncoder; -import com.google.firebase.encoders.ValueEncoder; -import com.google.firebase.encoders.config.Configurator; -import com.google.firebase.encoders.config.EncoderConfig; - -import java.io.IOException; -import java.io.OutputStream; -import java.util.Map; - -public class ProtobufEncoder { - - ProtobufEncoder( - Map, ObjectEncoder> objectEncoders, - Map, ValueEncoder> valueEncoders, - ObjectEncoder fallbackEncoder) { - } - - ProtobufEncoder() {} - - /** Encodes an arbitrary object and directly writes into the output stream. */ - public void encode(@NonNull Object value, @NonNull OutputStream outputStream) throws IOException { - } - - /** Encodes an arbitrary object and returns it as a byte array. */ - @NonNull - public byte[] encode(@NonNull Object value) { - return new byte[0]; - } - - public static Builder builder() { - return new Builder(); - } - - public static final class Builder implements EncoderConfig { - - @NonNull - @Override - public Builder registerEncoder( - @NonNull Class type, @NonNull ObjectEncoder encoder) { - return this; - } - - @NonNull - @Override - public Builder registerEncoder( - @NonNull Class type, @NonNull ValueEncoder encoder) { - return this; - } - - @NonNull - public Builder registerFallbackEncoder(@NonNull ObjectEncoder fallbackEncoder) { - return this; - } - - @NonNull - public Builder configureWith(@NonNull Configurator config) { - config.configure(this); - return this; - } - - public ProtobufEncoder build() { - return new ProtobufEncoder(); - } - } -} diff --git a/feature/mediapipe/src/main/java/com/google/firebase/encoders/proto/ProtobufValueEncoderContext.java b/feature/mediapipe/src/main/java/com/google/firebase/encoders/proto/ProtobufValueEncoderContext.java deleted file mode 100644 index dc2dcb33..00000000 --- a/feature/mediapipe/src/main/java/com/google/firebase/encoders/proto/ProtobufValueEncoderContext.java +++ /dev/null @@ -1,64 +0,0 @@ -package com.google.firebase.encoders.proto; - -import androidx.annotation.NonNull; -import androidx.annotation.Nullable; - -import com.google.firebase.encoders.FieldDescriptor; -import com.google.firebase.encoders.ValueEncoderContext; - -import java.io.IOException; - -public class ProtobufValueEncoderContext implements ValueEncoderContext { - - ProtobufValueEncoderContext(ProtobufDataEncoderContext objEncoderCtx) { - - } - - ProtobufValueEncoderContext() {} - - void resetContext(FieldDescriptor field, boolean skipDefault) { - - } - - @NonNull - @Override - public ValueEncoderContext add(@Nullable String value) throws IOException { - return this; - } - - @NonNull - @Override - public ValueEncoderContext add(float value) throws IOException { - return this; - } - - @NonNull - @Override - public ValueEncoderContext add(double value) throws IOException { - return this; - } - - @NonNull - @Override - public ValueEncoderContext add(int value) throws IOException { - return this; - } - - @NonNull - @Override - public ValueEncoderContext add(long value) throws IOException { - return this; - } - - @NonNull - @Override - public ValueEncoderContext add(boolean value) throws IOException { - return this; - } - - @NonNull - @Override - public ValueEncoderContext add(@NonNull byte[] bytes) throws IOException { - return this; - } -} diff --git a/feature/mediapipe/src/main/java/com/shifthackz/aisdv1/feature/mediapipe/extensions/ModelPaths.kt b/feature/mediapipe/src/main/java/com/shifthackz/aisdv1/feature/mediapipe/extensions/ModelPaths.kt deleted file mode 100644 index a92380c8..00000000 --- a/feature/mediapipe/src/main/java/com/shifthackz/aisdv1/feature/mediapipe/extensions/ModelPaths.kt +++ /dev/null @@ -1,2 +0,0 @@ -package com.shifthackz.aisdv1.feature.mediapipe.extensions - diff --git a/feature/work/src/main/java/com/shifthackz/aisdv1/work/core/CoreGenerationWorker.kt b/feature/work/src/main/java/com/shifthackz/aisdv1/work/core/CoreGenerationWorker.kt index d3076dee..ea7504b4 100644 --- a/feature/work/src/main/java/com/shifthackz/aisdv1/work/core/CoreGenerationWorker.kt +++ b/feature/work/src/main/java/com/shifthackz/aisdv1/work/core/CoreGenerationWorker.kt @@ -112,7 +112,7 @@ internal abstract class CoreGenerationWorker( setForegroundNotification( title = title, body = subTitle, - canCancel = source != ServerSource.LOCAL, + canCancel = source != ServerSource.LOCAL_MICROSOFT_ONNX, ) } diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index e35f09d0..d1a1645c 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -117,6 +117,7 @@ android-application = { id = "com.android.application", version.ref = "agp" } android-library = { id = "com.android.library", version.ref = "agp" } jetbrains-kotlin-android = { id = "org.jetbrains.kotlin.android", version.ref = "kotlin" } jetbrains-kotlin-kapt = { id = "org.jetbrains.kotlin.kapt", version="unspecified" } +generic-flavors = { id = "generic.flavors", version = "unspecified" } generic-library = { id = "generic.library", version = "unspecified" } generic-baseline-profm = { id = "generic.baseline.profm", version = "unspecified" } generic-compose = { id = "generic.compose", version = "unspecified" } diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/di/ViewModelModule.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/di/ViewModelModule.kt index 9d656073..a6060cb9 100755 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/di/ViewModelModule.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/di/ViewModelModule.kt @@ -84,6 +84,7 @@ val viewModelModule = module { preferenceManager = get(), wakeLockInterActor = get(), mainRouter = get(), + buildInfoProvider = get(), ) } diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageScreen.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageScreen.kt index fe8eeec1..2eee81b6 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageScreen.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageScreen.kt @@ -135,7 +135,7 @@ private fun ScreenContent( ) }, actions = { - if (state.mode != ServerSource.LOCAL) { + if (state.mode != ServerSource.LOCAL_MICROSOFT_ONNX) { IconButton( onClick = { processIntent( @@ -234,14 +234,14 @@ private fun ScreenContent( Text( modifier = Modifier.padding(top = 14.dp), text = stringResource( - if (state.mode == ServerSource.LOCAL) LocalizationR.string.local_no_img2img_support_sub_title + if (state.mode == ServerSource.LOCAL_MICROSOFT_ONNX) LocalizationR.string.local_no_img2img_support_sub_title else LocalizationR.string.dalle_no_img2img_support_sub_title ), ) Text( modifier = Modifier.padding(top = 14.dp), text = stringResource( - if (state.mode == ServerSource.LOCAL) LocalizationR.string.local_no_img2img_support_sub_title_2 + if (state.mode == ServerSource.LOCAL_MICROSOFT_ONNX) LocalizationR.string.local_no_img2img_support_sub_title_2 else LocalizationR.string.dalle_no_img2img_support_sub_title_2 ), ) @@ -251,7 +251,7 @@ private fun ScreenContent( }, bottomBar = { val isEnabled = when (state.mode) { - ServerSource.LOCAL, + ServerSource.LOCAL_MICROSOFT_ONNX, ServerSource.OPEN_AI -> true else -> !state.hasValidationErrors && !state.imageState.isEmpty @@ -271,7 +271,7 @@ private fun ScreenContent( keyboardController?.hide() when (state.mode) { ServerSource.OPEN_AI, - ServerSource.LOCAL -> processIntent(GenerationMviIntent.Configuration) + ServerSource.LOCAL_MICROSOFT_ONNX -> processIntent(GenerationMviIntent.Configuration) else -> { promptChipTextFieldState.value.text.takeIf(String::isNotBlank) @@ -292,7 +292,7 @@ private fun ScreenContent( }, enabled = isEnabled, ) { - if (state.mode != ServerSource.LOCAL) { + if (state.mode != ServerSource.LOCAL_MICROSOFT_ONNX) { Icon( modifier = Modifier.size(18.dp), imageVector = Icons.Default.AutoFixNormal, @@ -303,7 +303,7 @@ private fun ScreenContent( modifier = Modifier.padding(start = 8.dp), text = stringResource( id = when (state.mode) { - ServerSource.LOCAL, + ServerSource.LOCAL_MICROSOFT_ONNX, ServerSource.OPEN_AI -> LocalizationR.string.action_change_configuration else -> LocalizationR.string.action_generate diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/onboarding/page/LocalDiffusionPageContent.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/onboarding/page/LocalDiffusionPageContent.kt index fbfb6e79..9138f169 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/onboarding/page/LocalDiffusionPageContent.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/onboarding/page/LocalDiffusionPageContent.kt @@ -60,7 +60,7 @@ fun LocalDiffusionPageContent( modifier = localModifier, state = TextToImageState( onBoardingDemo = true, - mode = ServerSource.LOCAL, + mode = ServerSource.LOCAL_MICROSOFT_ONNX, advancedToggleButtonVisible = false, advancedOptionsVisible = true, formPromptTaggedInput = true, diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsScreen.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsScreen.kt index bb634bf1..e611a02e 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsScreen.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsScreen.kt @@ -235,7 +235,7 @@ private fun ContentSettingsState( ServerSource.HUGGING_FACE -> LocalizationR.string.srv_type_hugging_face_short ServerSource.OPEN_AI -> LocalizationR.string.srv_type_open_ai ServerSource.STABILITY_AI -> LocalizationR.string.srv_type_stability_ai - ServerSource.LOCAL -> LocalizationR.string.srv_type_local_short + ServerSource.LOCAL_MICROSOFT_ONNX -> LocalizationR.string.srv_type_local_short ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> LocalizationR.string.srv_type_media_pipe_short ServerSource.SWARM_UI -> LocalizationR.string.srv_type_swarm_ui }.asUiText(), @@ -257,7 +257,7 @@ private fun ContentSettingsState( endValueText = state.sdModelSelected.asUiText(), onClick = { processIntent(SettingsIntent.SdModel.OpenChooser) }, ) - if (state.showLocalUseNNAPI) { + if (state.showLocalMICROSOFTONNXUseNNAPI) { SettingsItem( modifier = itemModifier, loading = state.loading, diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsState.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsState.kt index aff0ad43..312ecd55 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsState.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsState.kt @@ -37,8 +37,8 @@ data class SettingsState( val showStabilityAiCredits: Boolean get() = serverSource == ServerSource.STABILITY_AI - val showLocalUseNNAPI: Boolean - get() = serverSource == ServerSource.LOCAL + val showLocalMICROSOFTONNXUseNNAPI: Boolean + get() = serverSource == ServerSource.LOCAL_MICROSOFT_ONNX val showSdModelSelector: Boolean get() = serverSource == ServerSource.AUTOMATIC1111 diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupScreen.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupScreen.kt index a83405f0..5635568b 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupScreen.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupScreen.kt @@ -153,7 +153,7 @@ fun ServerSetupScreenContent( onClick = { processIntent(ServerSetupIntent.MainButtonClick) }, enabled = when (state.step) { ServerSetupState.Step.CONFIGURE -> when (state.mode) { - ServerSource.LOCAL -> state.localModels.any { + ServerSource.LOCAL_MICROSOFT_ONNX -> state.localModels.any { it.downloaded && it.selected } @@ -168,7 +168,7 @@ fun ServerSetupScreenContent( id = when (state.step) { ServerSetupState.Step.SOURCE -> LocalizationR.string.next else -> when (state.mode) { - ServerSource.LOCAL -> LocalizationR.string.action_setup + ServerSource.LOCAL_MICROSOFT_ONNX -> LocalizationR.string.action_setup else -> LocalizationR.string.action_connect } }, diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt index 3c5ee8a7..c0ee4f0e 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt @@ -1,5 +1,7 @@ package com.shifthackz.aisdv1.presentation.screen.setup +import com.shifthackz.aisdv1.core.common.appbuild.BuildInfoProvider +import com.shifthackz.aisdv1.core.common.appbuild.BuildType import com.shifthackz.aisdv1.core.common.log.errorLog import com.shifthackz.aisdv1.core.common.schedulers.DispatchersProvider import com.shifthackz.aisdv1.core.common.schedulers.SchedulersProvider @@ -49,6 +51,7 @@ class ServerSetupViewModel( private val preferenceManager: PreferenceManager, private val wakeLockInterActor: WakeLockInterActor, private val mainRouter: MainRouter, + private val buildInfoProvider: BuildInfoProvider, ) : MviRxViewModel() { override val initialState = ServerSetupState( @@ -253,7 +256,7 @@ class ServerSetupViewModel( emitEffect(ServerSetupEffect.HideKeyboard) !when (currentState.mode) { ServerSource.HORDE -> connectToHorde() - ServerSource.LOCAL -> connectToLocalDiffusion() + ServerSource.LOCAL_MICROSOFT_ONNX -> connectToLocalDiffusion() ServerSource.AUTOMATIC1111 -> connectToAutomaticInstance() ServerSource.HUGGING_FACE -> connectToHuggingFace() ServerSource.OPEN_AI -> connectToOpenAi() @@ -293,7 +296,7 @@ class ServerSetupViewModel( } } - ServerSource.LOCAL -> { + ServerSource.LOCAL_MICROSOFT_ONNX -> { if (currentState.localCustomModel) { val validation = filePathValidator(currentState.localCustomModelPath) updateState { @@ -305,7 +308,9 @@ class ServerSetupViewModel( } } - ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> { + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> if (buildInfoProvider.type == BuildType.FOSS) { + false + } else { true } diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/components/ConfigurationModeButton.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/components/ConfigurationModeButton.kt index 19815b66..c84b0547 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/components/ConfigurationModeButton.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/components/ConfigurationModeButton.kt @@ -72,7 +72,8 @@ fun ConfigurationModeButton( ServerSource.OPEN_AI, ServerSource.STABILITY_AI, ServerSource.HUGGING_FACE -> Icons.Default.Cloud - ServerSource.LOCAL -> Icons.Default.Android + ServerSource.LOCAL_MICROSOFT_ONNX, + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> Icons.Default.Android else -> Icons.Default.QuestionMark }, contentDescription = null, @@ -91,9 +92,10 @@ fun ConfigurationModeButton( ServerSource.HORDE -> LocalizationR.string.hint_server_horde_sub_title ServerSource.HUGGING_FACE -> LocalizationR.string.hint_hugging_face_sub_title ServerSource.OPEN_AI -> LocalizationR.string.hint_open_ai_sub_title - ServerSource.LOCAL -> LocalizationR.string.hint_local_diffusion_sub_title + ServerSource.LOCAL_MICROSOFT_ONNX -> LocalizationR.string.hint_local_diffusion_sub_title ServerSource.STABILITY_AI -> LocalizationR.string.hint_stability_ai_sub_title ServerSource.SWARM_UI -> LocalizationR.string.hint_swarm_ui_sub_title + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> LocalizationR.string.hint_mediapipe_sub_title else -> null } descriptionId?.let { resId -> diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/steps/ConfigurationStep.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/steps/ConfigurationStep.kt index 28e602fa..eb8f59d2 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/steps/ConfigurationStep.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/steps/ConfigurationStep.kt @@ -33,7 +33,7 @@ fun ConfigurationStep( processIntent = processIntent, ) - ServerSource.LOCAL -> LocalDiffusionForm( + ServerSource.LOCAL_MICROSOFT_ONNX -> LocalDiffusionForm( state = state, buildInfoProvider = buildInfoProvider, processIntent = processIntent, diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageState.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageState.kt index 2dad0725..04466de6 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageState.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageState.kt @@ -133,7 +133,7 @@ fun TextToImageState.mapToPayload(): TextToImagePayload = with(this) { subSeedStrength = subSeedStrength, sampler = selectedSampler, nsfw = if (mode == ServerSource.HORDE) nsfw else false, - batchCount = if (mode == ServerSource.LOCAL) 1 else batchCount, + batchCount = if (mode == ServerSource.LOCAL_MICROSOFT_ONNX) 1 else batchCount, style = openAiStyle.key.takeIf { mode == ServerSource.OPEN_AI && openAiModel == OpenAiModel.DALL_E_3 }, diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageViewModel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageViewModel.kt index 515c3248..5b5d96a3 100755 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageViewModel.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageViewModel.kt @@ -67,7 +67,7 @@ class TextToImageViewModel( private val progressModal: Modal get() { - if (currentState.mode == ServerSource.LOCAL) { + if (currentState.mode == ServerSource.LOCAL_MICROSOFT_ONNX) { return Modal.Generating(canCancel = preferenceManager.localDiffusionAllowCancel) } return Modal.Communicating() diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionComponent.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionComponent.kt index b4c4785d..bf787e02 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionComponent.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionComponent.kt @@ -54,7 +54,7 @@ fun EngineSelectionComponent( onItemSelected = { intentHandler(EngineSelectionIntent(it)) }, ) - ServerSource.LOCAL -> DropdownTextField( + ServerSource.LOCAL_MICROSOFT_ONNX -> DropdownTextField( label = LocalizationR.string.hint_sd_model.asUiText(), loading = state.loading, modifier = modifier, diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModel.kt index 294a7c6c..a1f0a0d6 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModel.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModel.kt @@ -131,7 +131,7 @@ class EngineSelectionViewModel( ServerSource.STABILITY_AI -> preferenceManager.stabilityAiEngineId = intent.value - ServerSource.LOCAL -> preferenceManager.localModelId = intent.value + ServerSource.LOCAL_MICROSOFT_ONNX -> preferenceManager.localModelId = intent.value else -> Unit } diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputForm.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputForm.kt index c1da2d53..fe9b4914 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputForm.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputForm.kt @@ -148,7 +148,7 @@ fun GenerationInputForm( ServerSource.SWARM_UI, ServerSource.STABILITY_AI, ServerSource.HUGGING_FACE, - ServerSource.LOCAL -> EngineSelectionComponent( + ServerSource.LOCAL_MICROSOFT_ONNX -> EngineSelectionComponent( modifier = Modifier .fillMaxWidth() .padding(top = 8.dp), @@ -206,7 +206,7 @@ fun GenerationInputForm( ServerSource.SWARM_UI, ServerSource.HUGGING_FACE, ServerSource.STABILITY_AI, - ServerSource.LOCAL -> { + ServerSource.LOCAL_MICROSOFT_ONNX -> { if (state.formPromptTaggedInput) { ChipTextFieldWithItem( modifier = Modifier @@ -256,7 +256,7 @@ fun GenerationInputForm( when (state.mode) { ServerSource.HORDE, - ServerSource.LOCAL -> { + ServerSource.LOCAL_MICROSOFT_ONNX -> { DropdownTextField( modifier = localModifier.padding(end = 4.dp), label = LocalizationR.string.width.asUiText(), @@ -498,9 +498,11 @@ fun GenerationInputForm( else -> Unit } + //Steps not available for open ai if (state.mode != ServerSource.OPEN_AI) { val stepsMax = when (state.mode) { - ServerSource.LOCAL -> SAMPLING_STEPS_LOCAL_DIFFUSION_MAX + ServerSource.LOCAL_MICROSOFT_ONNX -> SAMPLING_STEPS_LOCAL_DIFFUSION_MAX + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> SAMPLING_STEPS_LOCAL_DIFFUSION_MAX ServerSource.STABILITY_AI -> SAMPLING_STEPS_RANGE_STABILITY_AI_MAX else -> SAMPLING_STEPS_RANGE_MAX } @@ -520,24 +522,31 @@ fun GenerationInputForm( processIntent(GenerationMviIntent.Update.SamplingSteps(it.roundToInt())) }, ) + } - Text( - modifier = Modifier.padding(top = 8.dp), - text = stringResource( - LocalizationR.string.hint_cfg_scale, - "${state.cfgScale.roundTo(2)}", - ), - ) - SliderTextInputField( - value = state.cfgScale, - valueRange = (CFG_SCALE_RANGE_MIN * 1f)..(CFG_SCALE_RANGE_MAX * 1f), - valueDiff = 0.5f, - steps = abs(CFG_SCALE_RANGE_MAX - CFG_SCALE_RANGE_MIN) * 2 - 1, - sliderColors = sliderColors, - onValueChange = { - processIntent(GenerationMviIntent.Update.CfgScale(it)) - }, - ) + // CFG scale not available on open ai and google media pipe + when (state.mode) { + ServerSource.OPEN_AI, + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> Unit + else -> { + Text( + modifier = Modifier.padding(top = 8.dp), + text = stringResource( + LocalizationR.string.hint_cfg_scale, + "${state.cfgScale.roundTo(2)}", + ), + ) + SliderTextInputField( + value = state.cfgScale, + valueRange = (CFG_SCALE_RANGE_MIN * 1f)..(CFG_SCALE_RANGE_MAX * 1f), + valueDiff = 0.5f, + steps = abs(CFG_SCALE_RANGE_MAX - CFG_SCALE_RANGE_MIN) * 2 - 1, + sliderColors = sliderColors, + onValueChange = { + processIntent(GenerationMviIntent.Update.CfgScale(it)) + }, + ) + } } when (state.mode) { @@ -549,9 +558,10 @@ fun GenerationInputForm( else -> Unit } - // Batch is not available for Local Diffusion - if (state.mode != ServerSource.LOCAL) { - batchComponent() + // Batch is not available for any Local + when (state.mode) { + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE, ServerSource.LOCAL_MICROSOFT_ONNX -> Unit + else -> batchComponent() } //Restore faces available only for A1111 if (state.mode == ServerSource.AUTOMATIC1111) { diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/source/ServerSourceLabel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/source/ServerSourceLabel.kt index 3bab42ec..4ee15fd4 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/source/ServerSourceLabel.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/source/ServerSourceLabel.kt @@ -15,7 +15,7 @@ fun ServerSource.getName(): String { fun ServerSource.getNameUiText(): UiText = when (this) { ServerSource.AUTOMATIC1111 -> LocalizationR.string.srv_type_own ServerSource.HORDE -> LocalizationR.string.srv_type_horde - ServerSource.LOCAL -> LocalizationR.string.srv_type_local + ServerSource.LOCAL_MICROSOFT_ONNX -> LocalizationR.string.srv_type_local ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> LocalizationR.string.srv_type_media_pipe ServerSource.HUGGING_FACE -> LocalizationR.string.srv_type_hugging_face ServerSource.OPEN_AI -> LocalizationR.string.srv_type_open_ai diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupScreenTest.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupScreenTest.kt index c82a1dbc..b2485734 100644 --- a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupScreenTest.kt +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupScreenTest.kt @@ -107,7 +107,7 @@ class ServerSetupScreenTest : CoreComposeTest { stubUiState.update { it.copy( step = ServerSetupState.Step.CONFIGURE, - mode = ServerSource.LOCAL + mode = ServerSource.LOCAL_MICROSOFT_ONNX ) } @@ -125,7 +125,7 @@ class ServerSetupScreenTest : CoreComposeTest { stubUiState.update { it.copy( step = ServerSetupState.Step.CONFIGURE, - mode = ServerSource.LOCAL, + mode = ServerSource.LOCAL_MICROSOFT_ONNX, localModels = mockLocalAiModels.mapToUi() ) } diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModelTest.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModelTest.kt index a75bad5b..12031b8d 100644 --- a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModelTest.kt +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModelTest.kt @@ -1,5 +1,6 @@ package com.shifthackz.aisdv1.presentation.screen.setup +import com.shifthackz.aisdv1.core.common.appbuild.BuildInfoProvider import com.shifthackz.aisdv1.core.validation.common.CommonStringValidator import com.shifthackz.aisdv1.core.validation.path.FilePathValidator import com.shifthackz.aisdv1.core.validation.url.UrlValidator @@ -67,6 +68,7 @@ class ServerSetupViewModelTest : CoreViewModelTest() { preferenceManager = stubPreferenceManager, wakeLockInterActor = stubWakeLockInterActor, mainRouter = stubMainRouter, + buildInfoProvider = BuildInfoProvider.stub, ) @Before @@ -317,8 +319,8 @@ class ServerSetupViewModelTest : CoreViewModelTest() { @Test fun `given received UpdateServerMode intent, expected mode field in UI state is LOCAL`() { - viewModel.processIntent(ServerSetupIntent.UpdateServerMode(ServerSource.LOCAL)) - val expected = ServerSource.LOCAL + viewModel.processIntent(ServerSetupIntent.UpdateServerMode(ServerSource.LOCAL_MICROSOFT_ONNX)) + val expected = ServerSource.LOCAL_MICROSOFT_ONNX val actual = viewModel.state.value.mode Assert.assertEquals(expected, actual) } diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModelTest.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModelTest.kt index 855a9128..5b00f390 100644 --- a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModelTest.kt +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModelTest.kt @@ -229,7 +229,7 @@ class EngineSelectionViewModelTest : CoreViewModelTest @Test fun `given received EngineSelectionIntent, source is LOCAL, expected localModelId changed in preference`() { - mockInitialData(DataTestCase.Mock, ServerSource.LOCAL) + mockInitialData(DataTestCase.Mock, ServerSource.LOCAL_MICROSOFT_ONNX) every { stubPreferenceManager::localModelId.set(any()) From f4c75272f69599e758a0284d52c210e5c75aff56 Mon Sep 17 00:00:00 2001 From: ShiftHackZ Date: Mon, 26 Aug 2024 12:01:16 +0300 Subject: [PATCH 03/10] Allowed modes fix --- .../com/shifthackz/aisdv1/domain/entity/ServerSource.kt | 4 ++++ .../presentation/screen/setup/ServerSetupViewModel.kt | 2 ++ .../presentation/screen/setup/mappers/ModesMapper.kt | 9 +++++++++ 3 files changed, 15 insertions(+) create mode 100644 presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/mappers/ModesMapper.kt diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/ServerSource.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/ServerSource.kt index bd8af79e..15244e1d 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/ServerSource.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/ServerSource.kt @@ -1,8 +1,11 @@ package com.shifthackz.aisdv1.domain.entity +import com.shifthackz.aisdv1.core.common.appbuild.BuildType + enum class ServerSource( val key: String, val featureTags: Set, + val allowedInBuilds: Set = setOf(BuildType.FOSS, BuildType.PLAY, BuildType.FULL), ) { AUTOMATIC1111( key = "custom", @@ -77,6 +80,7 @@ enum class ServerSource( FeatureTag.Txt2Img, FeatureTag.MultipleModels, ), + allowedInBuilds = setOf(BuildType.PLAY, BuildType.FULL), ); companion object { diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt index c0ee4f0e..1aa89611 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt @@ -27,6 +27,7 @@ import com.shifthackz.aisdv1.domain.usecase.settings.GetConfigurationUseCase import com.shifthackz.aisdv1.presentation.model.LaunchSource import com.shifthackz.aisdv1.presentation.model.Modal import com.shifthackz.aisdv1.presentation.navigation.router.main.MainRouter +import com.shifthackz.aisdv1.presentation.screen.setup.mappers.allowedModes import com.shifthackz.aisdv1.presentation.screen.setup.mappers.mapLocalCustomModelSwitchState import com.shifthackz.aisdv1.presentation.screen.setup.mappers.mapToUi import com.shifthackz.aisdv1.presentation.screen.setup.mappers.withNewState @@ -94,6 +95,7 @@ class ServerSetupViewModel( localCustomModel = localModels.mapLocalCustomModelSwitchState(), localCustomModelPath = configuration.localModelPath, mode = configuration.source, + allowedModes = buildInfoProvider.allowedModes, demoMode = configuration.demoMode, serverUrl = configuration.serverUrl, swarmUiUrl = configuration.swarmUiUrl, diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/mappers/ModesMapper.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/mappers/ModesMapper.kt new file mode 100644 index 00000000..1dd0b9c2 --- /dev/null +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/mappers/ModesMapper.kt @@ -0,0 +1,9 @@ +package com.shifthackz.aisdv1.presentation.screen.setup.mappers + +import com.shifthackz.aisdv1.core.common.appbuild.BuildInfoProvider +import com.shifthackz.aisdv1.domain.entity.ServerSource + +val BuildInfoProvider.allowedModes: List + get() = ServerSource + .entries + .filter { it.allowedInBuilds.contains(type) } From e3d146bcf4a5d95829d9cb136a4c065148f6e0e5 Mon Sep 17 00:00:00 2001 From: ShiftHackZ Date: Tue, 27 Aug 2024 12:03:29 +0300 Subject: [PATCH 04/10] Working MediaPipe prototype --- .../aisdv1/app/di/ProvidersModule.kt | 5 +- .../common/file/FileProviderDescriptor.kt | 1 - .../src/main/res/values/strings.xml | 11 +- .../local/DownloadableModelLocalDataSource.kt | 86 +++--- .../data/mappers/LocalAiModelMappers.kt | 14 +- .../data/preference/PreferenceManagerImpl.kt | 34 ++- .../DownloadableModelRemoteDataSource.kt | 18 +- .../DownloadableModelRepositoryImpl.kt | 18 +- .../LocalDiffusionGenerationRepositoryImpl.kt | 4 +- .../MediaPipeGenerationRepositoryImpl.kt | 4 +- .../DownloadableModelLocalDataSourceTest.kt | 102 ++----- .../data/mocks/LocalModelEntityMocks.kt | 2 + .../preference/PreferenceManagerImplTest.kt | 12 +- .../DownloadableModelRemoteDataSourceTest.kt | 9 +- .../DownloadableModelRepositoryImplTest.kt | 129 +-------- ...alDiffusionGenerationRepositoryImplTest.kt | 10 +- .../datasource/DownloadableModelDataSource.kt | 9 +- .../aisdv1/domain/di/DomainModule.kt | 15 +- .../aisdv1/domain/entity/Configuration.kt | 6 +- .../aisdv1/domain/entity/LocalAiModel.kt | 21 +- .../domain/feature/mediapipe/MediaPipe.kt | 1 - .../domain/preference/PreferenceManager.kt | 12 +- .../repository/DownloadableModelRepository.kt | 8 +- .../GetLocalMediaPipeModelsUseCase.kt | 8 + .../GetLocalMediaPipeModelsUseCaseImpl.kt | 10 + ...seCase.kt => GetLocalOnnxModelsUseCase.kt} | 2 +- ...pl.kt => GetLocalOnnxModelsUseCaseImpl.kt} | 6 +- ...se.kt => ObserveLocalOnnxModelsUseCase.kt} | 2 +- ...t => ObserveLocalOnnxModelsUseCaseImpl.kt} | 6 +- .../ConnectToLocalDiffusionUseCaseImpl.kt | 2 +- .../settings/ConnectToMediaPipeUseCase.kt | 2 +- .../settings/ConnectToMediaPipeUseCaseImpl.kt | 4 +- .../settings/GetConfigurationUseCaseImpl.kt | 6 +- .../SetServerConfigurationUseCaseImpl.kt | 6 +- .../aisdv1/domain/mocks/ConfigurationMocks.kt | 4 +- .../aisdv1/domain/mocks/LocalAiModelMocks.kt | 2 +- ...t => GetLocalOnnxModelsUseCaseImplTest.kt} | 10 +- ... ObserveLocalOnnxModelsUseCaseImplTest.kt} | 10 +- .../GetConfigurationUseCaseImplTest.kt | 8 +- .../SetServerConfigurationUseCaseImplTest.kt | 4 +- .../extensions/LocalDiffusionPaths.kt | 4 +- .../aisdv1/feature/mediapipe/MediaPipeImpl.kt | 6 +- .../aisdv1/feature/mediapipe/MediaPipeImpl.kt | 22 +- .../extensions/MediaPipeModelPaths.kt | 17 ++ .../aisdv1/feature/mediapipe/MediaPipeImpl.kt | 63 +++++ .../aisdv1/work/core/CoreGenerationWorker.kt | 2 +- .../network/api/sdai/DownloadableModelsApi.kt | 9 +- .../api/sdai/DownloadableModelsApiImpl.kt | 5 +- .../aisdv1/presentation/di/ViewModelModule.kt | 3 +- .../screen/debug/DebugMenuViewModel.kt | 4 +- .../screen/settings/SettingsViewModel.kt | 2 +- .../screen/setup/ServerSetupScreen.kt | 6 +- .../screen/setup/ServerSetupState.kt | 33 ++- .../screen/setup/ServerSetupViewModel.kt | 221 ++++++++++----- .../screen/setup/forms/LocalDiffusionForm.kt | 40 ++- .../screen/setup/forms/MediaPipeForm.kt | 29 ++ .../screen/setup/mappers/LocalModelMappers.kt | 7 +- .../screen/setup/steps/ConfigurationStep.kt | 7 +- .../screen/txt2img/TextToImageViewModel.kt | 2 +- .../widget/engine/EngineSelectionViewModel.kt | 12 +- .../presentation/mocks/LocalAiModelMocks.kt | 2 +- .../screen/settings/SettingsViewModelTest.kt | 4 +- .../screen/setup/ServerSetupScreenTest.kt | 14 +- .../screen/setup/ServerSetupViewModelTest.kt | 23 +- .../engine/EngineSelectionViewModelTest.kt | 14 +- .../6.json | 254 ++++++++++++++++++ .../db/persistent/PersistentDatabase.kt | 22 +- .../persistent/contract/LocalModelContract.kt | 1 + .../db/persistent/dao/LocalModelDao.kt | 6 + .../db/persistent/entity/LocalModelEntity.kt | 2 + 70 files changed, 967 insertions(+), 492 deletions(-) create mode 100644 domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalMediaPipeModelsUseCase.kt create mode 100644 domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalMediaPipeModelsUseCaseImpl.kt rename domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/{GetLocalAiModelsUseCase.kt => GetLocalOnnxModelsUseCase.kt} (84%) rename domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/{GetLocalAiModelsUseCaseImpl.kt => GetLocalOnnxModelsUseCaseImpl.kt} (59%) rename domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/{ObserveLocalAiModelsUseCase.kt => ObserveLocalOnnxModelsUseCase.kt} (83%) rename domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/{ObserveLocalAiModelsUseCaseImpl.kt => ObserveLocalOnnxModelsUseCaseImpl.kt} (70%) rename domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/downloadable/{GetLocalAiModelsUseCaseImplTest.kt => GetLocalOnnxModelsUseCaseImplTest.kt} (84%) rename domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/downloadable/{ObserveLocalAiModelsUseCaseImplTest.kt => ObserveLocalOnnxModelsUseCaseImplTest.kt} (91%) create mode 100644 feature/mediapipe/src/main/java/com/shifthackz/aisdv1/feature/mediapipe/extensions/MediaPipeModelPaths.kt create mode 100644 feature/mediapipe/src/playstore/java/com/shifthackz/aisdv1/feature/mediapipe/MediaPipeImpl.kt create mode 100644 presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/MediaPipeForm.kt create mode 100644 storage/schemas/com.shifthackz.aisdv1.storage.db.persistent.PersistentDatabase/6.json diff --git a/app/src/main/java/com/shifthackz/aisdv1/app/di/ProvidersModule.kt b/app/src/main/java/com/shifthackz/aisdv1/app/di/ProvidersModule.kt index 4aa4ad84..c9b147fb 100755 --- a/app/src/main/java/com/shifthackz/aisdv1/app/di/ProvidersModule.kt +++ b/app/src/main/java/com/shifthackz/aisdv1/app/di/ProvidersModule.kt @@ -170,21 +170,20 @@ val providersModule = module { override val imagesCacheDirPath: String = "${androidApplication().cacheDir}/images" override val logsCacheDirPath: String = "${androidApplication().cacheDir}/logs" override val localModelDirPath: String = "${androidApplication().filesDir.absolutePath}/model" - override val mediaPipeDirPath: String = "${androidApplication().filesDir.absolutePath}/out4" override val workCacheDirPath: String = "${androidApplication().cacheDir}/work" } } single { DeviceNNAPIFlagProvider { - get().localUseNNAPI + get().localOnnxUseNNAPI .let { nnApi -> if (nnApi) LocalDiffusionFlag.NN_API else LocalDiffusionFlag.CPU } .let(LocalDiffusionFlag::value) } } single { - LocalModelIdProvider { get().localModelId } + LocalModelIdProvider { get().localOnnxModelId } } single { diff --git a/core/common/src/main/java/com/shifthackz/aisdv1/core/common/file/FileProviderDescriptor.kt b/core/common/src/main/java/com/shifthackz/aisdv1/core/common/file/FileProviderDescriptor.kt index 2c4a6c4c..22ce7dbb 100644 --- a/core/common/src/main/java/com/shifthackz/aisdv1/core/common/file/FileProviderDescriptor.kt +++ b/core/common/src/main/java/com/shifthackz/aisdv1/core/common/file/FileProviderDescriptor.kt @@ -7,6 +7,5 @@ interface FileProviderDescriptor { val imagesCacheDirPath: String val logsCacheDirPath: String val localModelDirPath: String - val mediaPipeDirPath: String val workCacheDirPath: String } diff --git a/core/localization/src/main/res/values/strings.xml b/core/localization/src/main/res/values/strings.xml index a4f11e11..1ed2ae87 100755 --- a/core/localization/src/main/res/values/strings.xml +++ b/core/localization/src/main/res/values/strings.xml @@ -70,10 +70,10 @@ A1111 Horde AI Cloud Horde - Local Diffusion ONNX (Beta) - Local ONNX - Local Google AI MediaPipe - MediaPipe + Local Diffusion Microsoft ONNX (Beta) + ONNX + Local Google AI MediaPipe (Beta) + MediaPipe Hugging Face Inference HuggingFace Open AI @@ -152,11 +152,12 @@ Provide your Swarm UI URL A Modular Stable Diffusion Web-User-Interface, with an emphasis on making tools easily accessible, high performance, and extensibility. - Local Diffusion + Local Diffusion Microsoft ONNX This configuration uses Microsoft ONNX runtime and allows to run Stable Diffusion AI generations on your phone, with no need to connect to remote server/cloud. Warning! Local Diffusion functionality is in beta-test. Don\'t expect for high quality images using local mode. \n\nThis implementation may not work well on non-powerful phones. Generation performance and speed depends on your phone resources (CPU, RAM) and the size of generated image (the smaller the image size, the faster the generation). + Local Google AI MediaPipe This configuration uses Google AI MediaPipe and allows to run Stable Diffusion AI generations on your phone, with no need to connect to remote server/cloud. Web UI diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/local/DownloadableModelLocalDataSource.kt b/data/src/main/java/com/shifthackz/aisdv1/data/local/DownloadableModelLocalDataSource.kt index 0e4b9f02..d5fdcb66 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/local/DownloadableModelLocalDataSource.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/local/DownloadableModelLocalDataSource.kt @@ -11,6 +11,7 @@ import com.shifthackz.aisdv1.domain.preference.PreferenceManager import com.shifthackz.aisdv1.storage.db.persistent.dao.LocalModelDao import com.shifthackz.aisdv1.storage.db.persistent.entity.LocalModelEntity import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Flowable import io.reactivex.rxjava3.core.Observable import io.reactivex.rxjava3.core.Single import java.io.File @@ -22,72 +23,92 @@ internal class DownloadableModelLocalDataSource( private val buildInfoProvider: BuildInfoProvider, ) : DownloadableModelDataSource.Local { - override fun getAll() = dao - .query() + override fun getAllOnnx() = dao + .queryByType(LocalAiModel.Type.ONNX.key) .map(List::mapEntityToDomain) .map { models -> buildList { addAll(models) - if (buildInfoProvider.type == BuildType.FOSS) add(LocalAiModel.CUSTOM) + if (buildInfoProvider.type != BuildType.PLAY) { + add(LocalAiModel.CustomOnnx) + } + } + } + .flatMap { models -> models.withLocalData() } + + override fun getAllMediaPipe(): Single> = dao + .queryByType(LocalAiModel.Type.MediaPipe.key) + .map(List::mapEntityToDomain) + .map { models -> + buildList { + addAll(models) + if (buildInfoProvider.type != BuildType.PLAY) { + add(LocalAiModel.CustomMediaPipe) + } } } .flatMap { models -> models.withLocalData() } override fun getById(id: String): Single { - val chain = if (id == LocalAiModel.CUSTOM.id) { - Single.just(LocalAiModel.CUSTOM) - } else { - dao + val chain = when (id) { + LocalAiModel.CustomOnnx.id -> Single.just(LocalAiModel.CustomOnnx) + LocalAiModel.CustomMediaPipe.id -> Single.just(LocalAiModel.CustomMediaPipe) + else -> dao .queryById(id) .map(LocalModelEntity::mapEntityToDomain) } - return chain.flatMap { model -> model.withLocalData() } } - override fun getSelected() = Single - .just(preferenceManager.localModelId) - .onErrorResumeNext { Single.error(IllegalStateException("No selected model.")) } + override fun getSelectedOnnx() = Single + .just(preferenceManager.localOnnxModelId) .flatMap(::getById) .onErrorResumeNext { Single.error(IllegalStateException("No selected model.")) } - override fun observeAll() = dao - .observe() + override fun observeAllOnnx(): Flowable> = dao + .observeByType(LocalAiModel.Type.ONNX.key) .map(List::mapEntityToDomain) .map { models -> buildList { addAll(models) - if (buildInfoProvider.type == BuildType.FOSS) add(LocalAiModel.CUSTOM) + if (buildInfoProvider.type != BuildType.PLAY) add(LocalAiModel.CustomOnnx) } } .flatMap { models -> models.withLocalData().toFlowable() } - override fun select(id: String) = Completable.fromAction { - preferenceManager.localModelId = id - } - override fun save(list: List) = list - .filter { it.id != LocalAiModel.CUSTOM.id } + .filter { it.id != LocalAiModel.CustomOnnx.id } .mapDomainToEntity() .let(dao::insertList) - override fun isDownloaded(id: String) = Single.create { emitter -> + override fun delete(id: String): Completable = Completable.fromAction { + getLocalModelDirectory(id).deleteRecursively() + } + + private fun isDownloaded(model: LocalAiModel) = Single.create { emitter -> try { - if (id == LocalAiModel.CUSTOM.id) { - if (!emitter.isDisposed) emitter.onSuccess(true) - } else { - val files = getLocalModelFiles(id) - if (!emitter.isDisposed) emitter.onSuccess(files.size == 4) + when (model.id) { + LocalAiModel.CustomOnnx.id, + LocalAiModel.CustomMediaPipe.id -> emitter.onSuccess(true) + + else -> { + val files = getLocalModelFiles(model.id) + when (model.type) { + LocalAiModel.Type.ONNX -> { + emitter.onSuccess(files.size == 4) + } + + LocalAiModel.Type.MediaPipe -> { + emitter.onSuccess(files.isNotEmpty()) + } + } + } } } catch (e: Exception) { if (!emitter.isDisposed) emitter.onSuccess(false) } } - override fun delete(id: String): Completable = Completable.fromAction { - getLocalModelDirectory(id).deleteRecursively() - } - private fun getLocalModelDirectory(id: String): File { return File("${fileProviderDescriptor.localModelDirPath}/${id}") } @@ -103,11 +124,14 @@ internal class DownloadableModelLocalDataSource( .flatMapSingle { model -> model.withLocalData() } .toList() - private fun LocalAiModel.withLocalData() = isDownloaded(id) + private fun LocalAiModel.withLocalData() = isDownloaded(this) .map { downloaded -> copy( downloaded = downloaded, - selected = preferenceManager.localModelId == id, + selected = when (this.type) { + LocalAiModel.Type.ONNX -> preferenceManager.localOnnxModelId == id + LocalAiModel.Type.MediaPipe -> preferenceManager.localMediaPipeModelId == id + }, ) } } diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/LocalAiModelMappers.kt b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/LocalAiModelMappers.kt index d9261085..7d7033aa 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/LocalAiModelMappers.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/LocalAiModelMappers.kt @@ -5,12 +5,16 @@ import com.shifthackz.aisdv1.network.response.DownloadableModelResponse import com.shifthackz.aisdv1.storage.db.persistent.entity.LocalModelEntity //region RAW --> DOMAIN -fun List.mapRawToCheckpointDomain(): List = - map(DownloadableModelResponse::mapRawToCheckpointDomain) +fun List.mapRawToCheckpointDomain( + type: LocalAiModel.Type, +): List = map { it.mapRawToCheckpointDomain(type) } -fun DownloadableModelResponse.mapRawToCheckpointDomain(): LocalAiModel = with(this) { +fun DownloadableModelResponse.mapRawToCheckpointDomain( + type: LocalAiModel.Type, +): LocalAiModel = with(this) { LocalAiModel( id = id ?: "", + type = type, name = name ?: "", size = size ?: "", sources = sources ?: emptyList(), @@ -23,7 +27,7 @@ fun List.mapDomainToEntity(): List = map(LocalAiModel::mapDomainToEntity) fun LocalAiModel.mapDomainToEntity(): LocalModelEntity = with(this) { - LocalModelEntity(id, name, size, sources) + LocalModelEntity(id, type.key, name, size, sources) } //endregion @@ -32,6 +36,6 @@ fun List.mapEntityToDomain(): List = map(LocalModelEntity::mapEntityToDomain) fun LocalModelEntity.mapEntityToDomain(): LocalAiModel = with(this) { - LocalAiModel(id, name, size, sources) + LocalAiModel(id, LocalAiModel.Type.parse(type), name, size, sources) } //endregion diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImpl.kt index e88fcf47..ac47d61e 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImpl.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImpl.kt @@ -60,7 +60,16 @@ class PreferenceManagerImpl( .apply() .also { onPreferencesChanged() } - override var localDiffusionCustomModelPath: String + override var localMediaPipeCustomModelPath: String + get() = preferences.getString( + KEY_MEDIA_PIPE_CUSTOM_MODEL_PATH, + LOCAL_DIFFUSION_CUSTOM_PATH + ) ?: LOCAL_DIFFUSION_CUSTOM_PATH + set(value) = preferences.edit() + .putString(KEY_MEDIA_PIPE_CUSTOM_MODEL_PATH, value) + .apply() + + override var localOnnxCustomModelPath: String get() = preferences.getString( KEY_LOCAL_DIFFUSION_CUSTOM_MODEL_PATH, LOCAL_DIFFUSION_CUSTOM_PATH, @@ -69,14 +78,14 @@ class PreferenceManagerImpl( .putString(KEY_LOCAL_DIFFUSION_CUSTOM_MODEL_PATH, value) .apply() - override var localDiffusionAllowCancel: Boolean + override var localOnnxAllowCancel: Boolean get() = preferences.getBoolean(KEY_ALLOW_LOCAL_DIFFUSION_CANCEL, false) set(value) = preferences.edit() .putBoolean(KEY_ALLOW_LOCAL_DIFFUSION_CANCEL, value) .apply() .also { onPreferencesChanged() } - override var localDiffusionSchedulerThread: SchedulersToken + override var localOnnxSchedulerThread: SchedulersToken get() = preferences .getInt(KEY_LOCAL_DIFFUSION_SCHEDULER_THREAD, SchedulersToken.COMPUTATION.ordinal) .let { SchedulersToken.entries[it] } @@ -196,20 +205,27 @@ class PreferenceManagerImpl( .apply() .also { onPreferencesChanged() } - override var localModelId: String + override var localOnnxModelId: String get() = preferences.getString(KEY_LOCAL_MODEL_ID, "") ?: "" set(value) = preferences.edit() .putString(KEY_LOCAL_MODEL_ID, value) .apply() .also { onPreferencesChanged() } - override var localUseNNAPI: Boolean + override var localOnnxUseNNAPI: Boolean get() = preferences.getBoolean(KEY_LOCAL_NN_API, false) set(value) = preferences.edit() .putBoolean(KEY_LOCAL_NN_API, value) .apply() .also { onPreferencesChanged() } + override var localMediaPipeModelId: String + get() = preferences.getString(KEY_MEDIA_PIPE_MODEL_ID, "") ?: "" + set(value) = preferences.edit() + .putString(KEY_MEDIA_PIPE_MODEL_ID, value) + .apply() + .also { onPreferencesChanged() } + override var designUseSystemColorPalette: Boolean get() = preferences.getBoolean(KEY_DESIGN_DYNAMIC_COLORS, false) set(value) = preferences.edit() @@ -273,8 +289,8 @@ class PreferenceManagerImpl( sdModel = sdModel, demoMode = demoMode, developerMode = developerMode, - localDiffusionAllowCancel = localDiffusionAllowCancel, - localDiffusionSchedulerThread = localDiffusionSchedulerThread, + localDiffusionAllowCancel = localOnnxAllowCancel, + localDiffusionSchedulerThread = localOnnxSchedulerThread, monitorConnectivity = monitorConnectivity, backgroundGeneration = backgroundGeneration, autoSaveAiResults = autoSaveAiResults, @@ -283,7 +299,7 @@ class PreferenceManagerImpl( formPromptTaggedInput = formPromptTaggedInput, source = source, hordeApiKey = hordeApiKey, - localUseNNAPI = localUseNNAPI, + localUseNNAPI = localOnnxUseNNAPI, designUseSystemColorPalette = designUseSystemColorPalette, designUseSystemDarkTheme = designUseSystemDarkTheme, designDarkTheme = designDarkTheme, @@ -302,6 +318,7 @@ class PreferenceManagerImpl( const val KEY_DEMO_MODE = "key_demo_mode" const val KEY_DEVELOPER_MODE = "key_developer_mode" const val KEY_LOCAL_DIFFUSION_CUSTOM_MODEL_PATH = "key_local_diffusion_custom_model_path" + const val KEY_MEDIA_PIPE_CUSTOM_MODEL_PATH = "key_mediapipe_custom_model_path" const val KEY_ALLOW_LOCAL_DIFFUSION_CANCEL = "key_allow_local_diffusion_cancel" const val KEY_LOCAL_DIFFUSION_SCHEDULER_THREAD = "key_local_diffusion_scheduler_thread" const val KEY_MONITOR_CONNECTIVITY = "key_monitor_connectivity" @@ -319,6 +336,7 @@ class PreferenceManagerImpl( const val KEY_STABILITY_AI_ENGINE_ID_KEY = "key_stability_ai_engine_id_key" const val KEY_ON_BOARDING_COMPLETE = "key_on_boarding_complete" const val KEY_FORCE_SETUP_AFTER_UPDATE = "force_upd_setup_v0.x.x-v0.6.2" + const val KEY_MEDIA_PIPE_MODEL_ID = "key_mediapipe_model_id" const val KEY_LOCAL_MODEL_ID = "key_local_model_id" const val KEY_LOCAL_NN_API = "key_local_nn_api" const val KEY_DESIGN_DYNAMIC_COLORS = "key_design_dynamic_colors" diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/remote/DownloadableModelRemoteDataSource.kt b/data/src/main/java/com/shifthackz/aisdv1/data/remote/DownloadableModelRemoteDataSource.kt index a2e6b193..1a552534 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/remote/DownloadableModelRemoteDataSource.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/remote/DownloadableModelRemoteDataSource.kt @@ -5,10 +5,12 @@ import com.shifthackz.aisdv1.core.common.file.unzip import com.shifthackz.aisdv1.data.mappers.mapRawToCheckpointDomain import com.shifthackz.aisdv1.domain.datasource.DownloadableModelDataSource import com.shifthackz.aisdv1.domain.entity.DownloadState +import com.shifthackz.aisdv1.domain.entity.LocalAiModel import com.shifthackz.aisdv1.network.api.sdai.DownloadableModelsApi import com.shifthackz.aisdv1.network.response.DownloadableModelResponse import io.reactivex.rxjava3.core.Completable import io.reactivex.rxjava3.core.Observable +import io.reactivex.rxjava3.core.Single import java.io.File internal class DownloadableModelRemoteDataSource( @@ -16,12 +18,18 @@ internal class DownloadableModelRemoteDataSource( private val fileProviderDescriptor: FileProviderDescriptor, ) : DownloadableModelDataSource.Remote { - override fun fetch() = api - .fetchDownloadableModels() - .map(List::mapRawToCheckpointDomain) + override fun fetch(): Single> = Single.zip( + api + .fetchOnnxModels() + .map { it.mapRawToCheckpointDomain(LocalAiModel.Type.ONNX) }, + api + .fetchMediaPipeModels() + .map { it.mapRawToCheckpointDomain(LocalAiModel.Type.MediaPipe) }, + ::Pair, + ) + .map { (onnx, mediapipe) -> listOf(onnx, mediapipe).flatten() } - override fun download(id: String, url: String): Observable = - Completable + override fun download(id: String, url: String): Observable = Completable .fromAction { val dir = File("${fileProviderDescriptor.localModelDirPath}/${id}") val destination = File(getDestinationPath(id)) diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/repository/DownloadableModelRepositoryImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/repository/DownloadableModelRepositoryImpl.kt index 5bde3662..90e11c7a 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/repository/DownloadableModelRepositoryImpl.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/repository/DownloadableModelRepositoryImpl.kt @@ -8,8 +8,6 @@ internal class DownloadableModelRepositoryImpl( private val localDataSource: DownloadableModelDataSource.Local, ) : DownloadableModelRepository { - override fun isModelDownloaded(id: String) = localDataSource.isDownloaded(id) - override fun download(id: String) = localDataSource .getById(id) .flatMapObservable { model -> @@ -18,15 +16,17 @@ internal class DownloadableModelRepositoryImpl( override fun delete(id: String) = localDataSource.delete(id) - override fun getAll() = remoteDataSource + override fun getAllOnnx() = remoteDataSource .fetch() .flatMapCompletable(localDataSource::save) - .andThen(localDataSource.getAll()) - .onErrorResumeNext { localDataSource.getAll() } - - override fun getById(id: String) = localDataSource.getById(id) + .andThen(localDataSource.getAllOnnx()) + .onErrorResumeNext { localDataSource.getAllOnnx() } - override fun observeAll() = localDataSource.observeAll() + override fun getAllMediaPipe() = remoteDataSource + .fetch() + .flatMapCompletable(localDataSource::save) + .andThen(localDataSource.getAllMediaPipe()) + .onErrorResumeNext { localDataSource.getAllMediaPipe() } - override fun select(id: String) = localDataSource.select(id) + override fun observeAllOnnx() = localDataSource.observeAllOnnx() } diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/repository/LocalDiffusionGenerationRepositoryImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/repository/LocalDiffusionGenerationRepositoryImpl.kt index 0254b3d8..625a1ca8 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/repository/LocalDiffusionGenerationRepositoryImpl.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/repository/LocalDiffusionGenerationRepositoryImpl.kt @@ -36,7 +36,7 @@ internal class LocalDiffusionGenerationRepositoryImpl( override fun observeStatus() = localDiffusion.observeStatus() override fun generateFromText(payload: TextToImagePayload) = downloadableLocalDataSource - .getSelected() + .getSelectedOnnx() .flatMap { model -> if (model.downloaded) generate(payload) else Single.error(IllegalStateException("Model not downloaded.")) @@ -46,7 +46,7 @@ internal class LocalDiffusionGenerationRepositoryImpl( private fun generate(payload: TextToImagePayload) = localDiffusion .process(payload) - .subscribeOn(schedulersProvider.byToken(preferenceManager.localDiffusionSchedulerThread)) + .subscribeOn(schedulersProvider.byToken(preferenceManager.localOnnxSchedulerThread)) .map(BitmapToBase64Converter::Input) .flatMap(bitmapToBase64Converter::invoke) .map(BitmapToBase64Converter.Output::base64ImageString) diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/repository/MediaPipeGenerationRepositoryImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/repository/MediaPipeGenerationRepositoryImpl.kt index 4bac9dcc..88471209 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/repository/MediaPipeGenerationRepositoryImpl.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/repository/MediaPipeGenerationRepositoryImpl.kt @@ -1,5 +1,6 @@ package com.shifthackz.aisdv1.data.repository +import com.shifthackz.aisdv1.core.common.schedulers.SchedulersProvider import com.shifthackz.aisdv1.core.imageprocessing.Base64ToBitmapConverter import com.shifthackz.aisdv1.core.imageprocessing.BitmapToBase64Converter import com.shifthackz.aisdv1.data.core.CoreGenerationRepository @@ -21,6 +22,7 @@ internal class MediaPipeGenerationRepositoryImpl( localDataSource: GenerationResultDataSource.Local, backgroundWorkObserver: BackgroundWorkObserver, preferenceManager: PreferenceManager, + private val schedulersProvider: SchedulersProvider, private val mediaPipe: MediaPipe, private val bitmapToBase64Converter: BitmapToBase64Converter, ) : CoreGenerationRepository( @@ -33,7 +35,7 @@ internal class MediaPipeGenerationRepositoryImpl( override fun generateFromText(payload: TextToImagePayload): Single = mediaPipe .process(payload) - .subscribeOn(Schedulers.computation()) + .subscribeOn(schedulersProvider.singleThread.let(Schedulers::from)) .map(BitmapToBase64Converter::Input) .flatMap(bitmapToBase64Converter::invoke) .map(BitmapToBase64Converter.Output::base64ImageString) diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/local/DownloadableModelLocalDataSourceTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/local/DownloadableModelLocalDataSourceTest.kt index e1c2c830..73a5b0eb 100644 --- a/data/src/test/java/com/shifthackz/aisdv1/data/local/DownloadableModelLocalDataSourceTest.kt +++ b/data/src/test/java/com/shifthackz/aisdv1/data/local/DownloadableModelLocalDataSourceTest.kt @@ -13,8 +13,6 @@ import com.shifthackz.aisdv1.storage.db.persistent.dao.LocalModelDao import com.shifthackz.aisdv1.storage.db.persistent.entity.LocalModelEntity import io.mockk.every import io.mockk.mockk -import io.mockk.mockkConstructor -import io.mockk.mockkStatic import io.reactivex.rxjava3.core.BackpressureStrategy import io.reactivex.rxjava3.core.Completable import io.reactivex.rxjava3.core.Flowable @@ -22,7 +20,6 @@ import io.reactivex.rxjava3.core.Single import io.reactivex.rxjava3.subjects.BehaviorSubject import org.junit.Assert import org.junit.Test -import java.io.File class DownloadableModelLocalDataSourceTest { @@ -51,7 +48,7 @@ class DownloadableModelLocalDataSourceTest { } returns BuildType.PLAY every { - stubPreferenceManager.localModelId + stubPreferenceManager.localOnnxModelId } returns "" every { @@ -61,7 +58,7 @@ class DownloadableModelLocalDataSourceTest { val expected = mockLocalModelEntities.mapEntityToDomain() localDataSource - .getAll() + .getAllOnnx() .test() .assertNoErrors() .assertValue { actual -> @@ -82,11 +79,11 @@ class DownloadableModelLocalDataSourceTest { } returns BuildType.PLAY every { - stubPreferenceManager.localModelId + stubPreferenceManager.localOnnxModelId } returns "" localDataSource - .getAll() + .getAllOnnx() .test() .assertNoErrors() .assertValue(emptyList()) @@ -105,7 +102,7 @@ class DownloadableModelLocalDataSourceTest { } returns BuildType.FOSS every { - stubPreferenceManager.localModelId + stubPreferenceManager.localOnnxModelId } returns "" every { @@ -114,11 +111,11 @@ class DownloadableModelLocalDataSourceTest { val expected = buildList { addAll(mockLocalModelEntities.mapEntityToDomain()) - add(LocalAiModel.CUSTOM.copy(downloaded = true)) + add(LocalAiModel.CustomOnnx.copy(downloaded = true)) } localDataSource - .getAll() + .getAllOnnx() .test() .assertNoErrors() .assertValue { actual -> @@ -139,14 +136,14 @@ class DownloadableModelLocalDataSourceTest { } returns BuildType.FOSS every { - stubPreferenceManager.localModelId + stubPreferenceManager.localOnnxModelId } returns "" localDataSource - .getAll() + .getAllOnnx() .test() .assertNoErrors() - .assertValue(listOf(LocalAiModel.CUSTOM.copy(downloaded = true))) + .assertValue(listOf(LocalAiModel.CustomOnnx.copy(downloaded = true))) .await() .assertComplete() } @@ -158,7 +155,7 @@ class DownloadableModelLocalDataSourceTest { } returns Single.error(stubException) localDataSource - .getAll() + .getAllOnnx() .test() .assertError(stubException) .assertNoValues() @@ -173,7 +170,7 @@ class DownloadableModelLocalDataSourceTest { } returns Single.just(mockLocalModelEntity) every { - stubPreferenceManager.localModelId + stubPreferenceManager.localOnnxModelId } returns "" every { @@ -198,7 +195,7 @@ class DownloadableModelLocalDataSourceTest { } returns Single.just(mockLocalModelEntity) every { - stubPreferenceManager.localModelId + stubPreferenceManager.localOnnxModelId } returns "5598" every { @@ -238,7 +235,7 @@ class DownloadableModelLocalDataSourceTest { } returns Single.just(mockLocalModelEntity) every { - stubPreferenceManager.localModelId + stubPreferenceManager.localOnnxModelId } returns "5598" every { @@ -248,7 +245,7 @@ class DownloadableModelLocalDataSourceTest { val expected = mockLocalModelEntity.mapEntityToDomain().copy(selected = true) localDataSource - .getSelected() + .getSelectedOnnx() .test() .assertNoErrors() .assertValue(expected) @@ -259,11 +256,11 @@ class DownloadableModelLocalDataSourceTest { @Test fun `given attempt to get selected model, preference throws exception, expected error value`() { every { - stubPreferenceManager.localModelId + stubPreferenceManager.localOnnxModelId } returns "" localDataSource - .getSelected() + .getSelectedOnnx() .test() .assertError { t -> t is IllegalStateException && t.message == "No selected model." @@ -284,7 +281,7 @@ class DownloadableModelLocalDataSourceTest { } returns BuildType.PLAY every { - stubPreferenceManager.localModelId + stubPreferenceManager.localOnnxModelId } returns "" every { @@ -292,7 +289,7 @@ class DownloadableModelLocalDataSourceTest { } returns "/tmp/local" val stubObserver = localDataSource - .observeAll() + .observeAllOnnx() .test() stubLocalModels.onNext(emptyList()) @@ -319,7 +316,7 @@ class DownloadableModelLocalDataSourceTest { } returns BuildType.FOSS every { - stubPreferenceManager.localModelId + stubPreferenceManager.localOnnxModelId } returns "" every { @@ -327,14 +324,14 @@ class DownloadableModelLocalDataSourceTest { } returns "/tmp/local" val stubObserver = localDataSource - .observeAll() + .observeAllOnnx() .test() stubLocalModels.onNext(emptyList()) stubObserver .assertNoErrors() - .assertValueAt(0, listOf(LocalAiModel.CUSTOM.copy(downloaded = true))) + .assertValueAt(0, listOf(LocalAiModel.CustomOnnx.copy(downloaded = true))) stubLocalModels.onNext(mockLocalModelEntities) @@ -342,7 +339,7 @@ class DownloadableModelLocalDataSourceTest { .assertNoErrors() .assertValueAt(1, buildList { addAll(mockLocalModelEntities.mapEntityToDomain()) - add(LocalAiModel.CUSTOM.copy(downloaded = true)) + add(LocalAiModel.CustomOnnx.copy(downloaded = true)) }) } @@ -353,7 +350,7 @@ class DownloadableModelLocalDataSourceTest { } returns Flowable.error(stubException) localDataSource - .observeAll() + .observeAllOnnx() .test() .assertError(stubException) .assertNoValues() @@ -361,44 +358,6 @@ class DownloadableModelLocalDataSourceTest { .assertNotComplete() } - @Test - fun `given attempt to select model, preference changed, expected preference returns changed selected model id value`() { - every { - stubPreferenceManager.localModelId - } returns "" - - every { - stubPreferenceManager::localModelId.set(any()) - } returns Unit - - localDataSource - .select("5598") - .test() - .assertNoErrors() - .await() - .assertComplete() - - every { - stubPreferenceManager.localModelId - } returns "5598" - - Assert.assertEquals("5598", stubPreferenceManager.localModelId) - } - - @Test - fun `given attempt to select model, preference throws exception, expected error value`() { - every { - stubPreferenceManager::localModelId.set(any()) - } throws stubException - - localDataSource - .select("5598") - .test() - .assertError(stubException) - .await() - .assertNotComplete() - } - @Test fun `given attempt to save local model list, dao insert success, expected complete value`() { every { @@ -427,8 +386,6 @@ class DownloadableModelLocalDataSourceTest { .assertNotComplete() } - //-- - @Test fun `given attempt to delete file, delete operation success, expected complete value`() { every { @@ -456,15 +413,4 @@ class DownloadableModelLocalDataSourceTest { .await() .assertNotComplete() } - - @Test - fun `given attempt to check if CUSTOM model is downloaded, expected true`() { - localDataSource - .isDownloaded(LocalAiModel.CUSTOM.id) - .test() - .assertNoErrors() - .assertValue(true) - .await() - .assertComplete() - } } diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/LocalModelEntityMocks.kt b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/LocalModelEntityMocks.kt index 33b45cd9..7beaa4be 100644 --- a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/LocalModelEntityMocks.kt +++ b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/LocalModelEntityMocks.kt @@ -4,6 +4,7 @@ import com.shifthackz.aisdv1.storage.db.persistent.entity.LocalModelEntity val mockLocalModelEntity = LocalModelEntity( id = "5598", + type = "onnx", name = "Best model in entire universe", size = "5598 Gb", sources = listOf("https://5598.is.my.favourite.com"), @@ -12,6 +13,7 @@ val mockLocalModelEntity = LocalModelEntity( val mockLocalModelEntities = listOf( LocalModelEntity( id = "1", + type = "onnx", name = "Model 1", size = "1 Gb", sources = listOf("https://example.com/1.php"), diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImplTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImplTest.kt index b1d953ae..20b1f5b3 100644 --- a/data/src/test/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImplTest.kt +++ b/data/src/test/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImplTest.kt @@ -373,14 +373,14 @@ class PreferenceManagerImplTest { whenever(stubPreference.getString(eq(KEY_LOCAL_MODEL_ID), any())) .thenReturn("") - Assert.assertEquals("", preferenceManager.localModelId) + Assert.assertEquals("", preferenceManager.localOnnxModelId) whenever(stubPreference.getString(eq(KEY_LOCAL_MODEL_ID), any())) .thenReturn("key") - preferenceManager.localModelId = "key" + preferenceManager.localOnnxModelId = "key" - Assert.assertEquals("key", preferenceManager.localModelId) + Assert.assertEquals("key", preferenceManager.localOnnxModelId) } @Test @@ -388,14 +388,14 @@ class PreferenceManagerImplTest { whenever(stubPreference.getBoolean(eq(KEY_LOCAL_NN_API), any())) .thenReturn(false) - Assert.assertEquals(false, preferenceManager.localUseNNAPI) + Assert.assertEquals(false, preferenceManager.localOnnxUseNNAPI) whenever(stubPreference.getBoolean(eq(KEY_LOCAL_NN_API), any())) .thenReturn(true) - preferenceManager.localUseNNAPI = true + preferenceManager.localOnnxUseNNAPI = true - Assert.assertEquals(true, preferenceManager.localUseNNAPI) + Assert.assertEquals(true, preferenceManager.localOnnxUseNNAPI) preferenceManager .observe() diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/remote/DownloadableModelRemoteDataSourceTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/remote/DownloadableModelRemoteDataSourceTest.kt index c97b0a9e..2b57a295 100644 --- a/data/src/test/java/com/shifthackz/aisdv1/data/remote/DownloadableModelRemoteDataSourceTest.kt +++ b/data/src/test/java/com/shifthackz/aisdv1/data/remote/DownloadableModelRemoteDataSourceTest.kt @@ -5,6 +5,7 @@ import com.nhaarman.mockitokotlin2.whenever import com.shifthackz.aisdv1.core.common.file.FileProviderDescriptor import com.shifthackz.aisdv1.data.mappers.mapRawToCheckpointDomain import com.shifthackz.aisdv1.data.mocks.mockDownloadableModelsResponse +import com.shifthackz.aisdv1.domain.entity.LocalAiModel import com.shifthackz.aisdv1.network.api.sdai.DownloadableModelsApi import io.reactivex.rxjava3.core.Single import org.junit.Test @@ -22,10 +23,10 @@ class DownloadableModelRemoteDataSourceTest { @Test fun `given attempt to fetch models list, api returns data, expected valid domain models list`() { - whenever(stubApi.fetchDownloadableModels()) + whenever(stubApi.fetchOnnxModels()) .thenReturn(Single.just(mockDownloadableModelsResponse)) - val expected = mockDownloadableModelsResponse.mapRawToCheckpointDomain() + val expected = mockDownloadableModelsResponse.mapRawToCheckpointDomain(LocalAiModel.Type.ONNX) remoteDataSource .fetch() @@ -38,7 +39,7 @@ class DownloadableModelRemoteDataSourceTest { @Test fun `given attempt to fetch models list, api returns empty data, expected empty domain models list`() { - whenever(stubApi.fetchDownloadableModels()) + whenever(stubApi.fetchOnnxModels()) .thenReturn(Single.just(emptyList())) remoteDataSource @@ -52,7 +53,7 @@ class DownloadableModelRemoteDataSourceTest { @Test fun `given attempt to fetch models list, api returns error, expected error value`() { - whenever(stubApi.fetchDownloadableModels()) + whenever(stubApi.fetchOnnxModels()) .thenReturn(Single.error(stubException)) remoteDataSource diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/repository/DownloadableModelRepositoryImplTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/repository/DownloadableModelRepositoryImplTest.kt index 2daf1b5f..acfb8f74 100644 --- a/data/src/test/java/com/shifthackz/aisdv1/data/repository/DownloadableModelRepositoryImplTest.kt +++ b/data/src/test/java/com/shifthackz/aisdv1/data/repository/DownloadableModelRepositoryImplTest.kt @@ -34,7 +34,7 @@ class DownloadableModelRepositoryImplTest { @Before fun initialize() { every { - stubLocalDataSource.observeAll() + stubLocalDataSource.observeAllOnnx() } returns stubLocalModels.toFlowable(BackpressureStrategy.LATEST) every { @@ -42,51 +42,6 @@ class DownloadableModelRepositoryImplTest { } returns stubDownloadState } - @Test - fun `given attempt to check if model downloaded, local data source returns true, expected true value`() { - every { - stubLocalDataSource.isDownloaded(any()) - } returns Single.just(true) - - repository - .isModelDownloaded("5598") - .test() - .assertNoErrors() - .assertValue(true) - .await() - .assertComplete() - } - - @Test - fun `given attempt to check if model downloaded, local data source returns false, expected false value`() { - every { - stubLocalDataSource.isDownloaded(any()) - } returns Single.just(false) - - repository - .isModelDownloaded("5598") - .test() - .assertNoErrors() - .assertValue(false) - .await() - .assertComplete() - } - - @Test - fun `given attempt to check if model downloaded, local data source throws exception, expected error value`() { - every { - stubLocalDataSource.isDownloaded(any()) - } returns Single.error(stubException) - - repository - .isModelDownloaded("5598") - .test() - .assertError(stubException) - .assertNoValues() - .await() - .assertNotComplete() - } - @Test fun `given attempt to delete model, local data source completes, expected complete value`() { every { @@ -115,34 +70,6 @@ class DownloadableModelRepositoryImplTest { .assertNotComplete() } - @Test - fun `given attempt to select model, local data source completes, expected complete value`() { - every { - stubLocalDataSource.select(any()) - } returns Completable.complete() - - repository - .select("5598") - .test() - .assertNoErrors() - .await() - .assertComplete() - } - - @Test - fun `given attempt to select model, local data source throws exception, expected error value`() { - every { - stubLocalDataSource.select(any()) - } returns Completable.error(stubException) - - repository - .select("5598") - .test() - .assertError(stubException) - .await() - .assertNotComplete() - } - @Test fun `given attempt to get all, remote returns list, save success, local query success, expected valid domain model list value`() { every { @@ -154,11 +81,11 @@ class DownloadableModelRepositoryImplTest { } returns Completable.complete() every { - stubLocalDataSource.getAll() + stubLocalDataSource.getAllOnnx() } returns Single.just(mockLocalAiModels) repository - .getAll() + .getAllOnnx() .test() .assertNoErrors() .assertValue(mockLocalAiModels) @@ -177,11 +104,11 @@ class DownloadableModelRepositoryImplTest { } returns Completable.error(stubException) every { - stubLocalDataSource.getAll() + stubLocalDataSource.getAllOnnx() } returns Single.just(mockLocalAiModels) repository - .getAll() + .getAllOnnx() .test() .assertNoErrors() .assertValue(mockLocalAiModels) @@ -200,11 +127,11 @@ class DownloadableModelRepositoryImplTest { } returns Completable.complete() every { - stubLocalDataSource.getAll() + stubLocalDataSource.getAllOnnx() } returns Single.just(mockLocalAiModels) repository - .getAll() + .getAllOnnx() .test() .assertNoErrors() .assertValue(mockLocalAiModels) @@ -223,41 +150,11 @@ class DownloadableModelRepositoryImplTest { } returns Completable.complete() every { - stubLocalDataSource.getAll() - } returns Single.error(stubException) - - repository - .getAll() - .test() - .assertError(stubException) - .assertNoValues() - .await() - .assertNotComplete() - } - - @Test - fun `given attempt to get by id, local data source returns data, expected valid domain model value`() { - every { - stubLocalDataSource.getById(any()) - } returns Single.just(mockLocalAiModel) - - repository - .getById("5598") - .test() - .assertNoErrors() - .assertValue(mockLocalAiModel) - .await() - .assertComplete() - } - - @Test - fun `given attempt to get by id, local data source fails, expected error value`() { - every { - stubLocalDataSource.getById(any()) + stubLocalDataSource.getAllOnnx() } returns Single.error(stubException) repository - .getById("5598") + .getAllOnnx() .test() .assertError(stubException) .assertNoValues() @@ -267,7 +164,7 @@ class DownloadableModelRepositoryImplTest { @Test fun `given observe all models, local data source emits empty list, then another list, expected empty value, then valid domain models list value`() { - val stubObserver = repository.observeAll().test() + val stubObserver = repository.observeAllOnnx().test() stubLocalModels.onNext(emptyList()) @@ -284,7 +181,7 @@ class DownloadableModelRepositoryImplTest { @Test fun `given observe all models, local data source emits list, then changed list, expected valid domain models list value, then changed value`() { - val stubObserver = repository.observeAll().test() + val stubObserver = repository.observeAllOnnx().test() stubLocalModels.onNext(mockLocalAiModels) @@ -302,11 +199,11 @@ class DownloadableModelRepositoryImplTest { @Test fun `given observe all models, local data source throws exception, expected error value`() { every { - stubLocalDataSource.observeAll() + stubLocalDataSource.observeAllOnnx() } returns Flowable.error(stubException) repository - .observeAll() + .observeAllOnnx() .test() .assertError(stubException) .assertNoValues() diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/repository/LocalDiffusionGenerationRepositoryImplTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/repository/LocalDiffusionGenerationRepositoryImplTest.kt index 75df885f..88cfd515 100644 --- a/data/src/test/java/com/shifthackz/aisdv1/data/repository/LocalDiffusionGenerationRepositoryImplTest.kt +++ b/data/src/test/java/com/shifthackz/aisdv1/data/repository/LocalDiffusionGenerationRepositoryImplTest.kt @@ -63,7 +63,7 @@ class LocalDiffusionGenerationRepositoryImplTest { @Before fun initialize() { every { - stubPreferenceManager::localDiffusionSchedulerThread.get() + stubPreferenceManager::localOnnxSchedulerThread.get() } returns SchedulersToken.COMPUTATION every { @@ -142,7 +142,7 @@ class LocalDiffusionGenerationRepositoryImplTest { @Test fun `given attempt to generate from text, no selected model, expected error value`() { every { - stubDownloadableLocalDataSource.getSelected() + stubDownloadableLocalDataSource.getSelectedOnnx() } returns Single.error(stubException) repository @@ -157,7 +157,7 @@ class LocalDiffusionGenerationRepositoryImplTest { @Test fun `given attempt to generate from text, has selected not downloaded model, expected IllegalStateException error value`() { every { - stubDownloadableLocalDataSource.getSelected() + stubDownloadableLocalDataSource.getSelectedOnnx() } returns Single.just(mockLocalAiModel.copy(downloaded = false)) every { @@ -182,7 +182,7 @@ class LocalDiffusionGenerationRepositoryImplTest { @Test fun `given attempt to generate from text, has selected downloaded model, local process success, expected valid domain model value`() { every { - stubDownloadableLocalDataSource.getSelected() + stubDownloadableLocalDataSource.getSelectedOnnx() } returns Single.just(mockLocalAiModel.copy(downloaded = true)) every { @@ -205,7 +205,7 @@ class LocalDiffusionGenerationRepositoryImplTest { @Test fun `given attempt to generate from text, has selected downloaded model, local process fails, expected error value`() { every { - stubDownloadableLocalDataSource.getSelected() + stubDownloadableLocalDataSource.getSelectedOnnx() } returns Single.just(mockLocalAiModel.copy(downloaded = true)) every { diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/DownloadableModelDataSource.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/DownloadableModelDataSource.kt index b2320323..cb56c621 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/DownloadableModelDataSource.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/DownloadableModelDataSource.kt @@ -15,13 +15,12 @@ sealed interface DownloadableModelDataSource { } interface Local : DownloadableModelDataSource { - fun getAll(): Single> + fun getAllOnnx(): Single> + fun getAllMediaPipe(): Single> fun getById(id: String): Single - fun getSelected(): Single - fun observeAll(): Flowable> - fun select(id: String): Completable + fun getSelectedOnnx(): Single + fun observeAllOnnx(): Flowable> fun save(list: List): Completable - fun isDownloaded(id: String): Single fun delete(id: String): Completable } } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/di/DomainModule.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/di/DomainModule.kt index 535ee266..5bdaf74e 100755 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/di/DomainModule.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/di/DomainModule.kt @@ -36,10 +36,12 @@ import com.shifthackz.aisdv1.domain.usecase.downloadable.DeleteModelUseCase import com.shifthackz.aisdv1.domain.usecase.downloadable.DeleteModelUseCaseImpl import com.shifthackz.aisdv1.domain.usecase.downloadable.DownloadModelUseCase import com.shifthackz.aisdv1.domain.usecase.downloadable.DownloadModelUseCaseImpl -import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalAiModelsUseCase -import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalAiModelsUseCaseImpl -import com.shifthackz.aisdv1.domain.usecase.downloadable.ObserveLocalAiModelsUseCase -import com.shifthackz.aisdv1.domain.usecase.downloadable.ObserveLocalAiModelsUseCaseImpl +import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalMediaPipeModelsUseCase +import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalMediaPipeModelsUseCaseImpl +import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalOnnxModelsUseCase +import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalOnnxModelsUseCaseImpl +import com.shifthackz.aisdv1.domain.usecase.downloadable.ObserveLocalOnnxModelsUseCase +import com.shifthackz.aisdv1.domain.usecase.downloadable.ObserveLocalOnnxModelsUseCaseImpl import com.shifthackz.aisdv1.domain.usecase.gallery.DeleteAllGalleryUseCase import com.shifthackz.aisdv1.domain.usecase.gallery.DeleteAllGalleryUseCaseImpl import com.shifthackz.aisdv1.domain.usecase.gallery.DeleteGalleryItemUseCase @@ -157,9 +159,10 @@ internal val useCasesModule = module { factoryOf(::SaveLastResultToCacheUseCaseImpl) bind SaveLastResultToCacheUseCase::class factoryOf(::GetLastResultFromCacheUseCaseImpl) bind GetLastResultFromCacheUseCase::class factoryOf(::ObserveLocalDiffusionProcessStatusUseCaseImpl) bind ObserveLocalDiffusionProcessStatusUseCase::class - factoryOf(::GetLocalAiModelsUseCaseImpl) bind GetLocalAiModelsUseCase::class + factoryOf(::GetLocalOnnxModelsUseCaseImpl) bind GetLocalOnnxModelsUseCase::class + factoryOf(::GetLocalMediaPipeModelsUseCaseImpl) bind GetLocalMediaPipeModelsUseCase::class factoryOf(::DownloadModelUseCaseImpl) bind DownloadModelUseCase::class - factoryOf(::ObserveLocalAiModelsUseCaseImpl) bind ObserveLocalAiModelsUseCase::class + factoryOf(::ObserveLocalOnnxModelsUseCaseImpl) bind ObserveLocalOnnxModelsUseCase::class factoryOf(::DeleteModelUseCaseImpl) bind DeleteModelUseCase::class factoryOf(::AcquireWakelockUseCaseImpl) bind AcquireWakelockUseCase::class factoryOf(::ReleaseWakeLockUseCaseImpl) bind ReleaseWakeLockUseCase::class diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/Configuration.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/Configuration.kt index fcf04c19..40c5e10a 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/Configuration.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/Configuration.kt @@ -15,6 +15,8 @@ data class Configuration( val stabilityAiApiKey: String = "", val stabilityAiEngineId: String = "", val authCredentials: AuthorizationCredentials = AuthorizationCredentials.None, - val localModelId: String = "", - val localModelPath: String = "", + val localOnnxModelId: String = "", + val localOnnxModelPath: String = "", + val localMediaPipeModelId: String = "", + val localMediaPipeModelPath: String = "", ) diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/LocalAiModel.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/LocalAiModel.kt index 734ad5ed..b7958512 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/LocalAiModel.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/LocalAiModel.kt @@ -2,15 +2,34 @@ package com.shifthackz.aisdv1.domain.entity data class LocalAiModel( val id: String, + val type: Type, val name: String, val size: String, val sources: List, val downloaded: Boolean = false, val selected: Boolean = false, ) { + enum class Type(val key: String) { + ONNX("onnx"), + MediaPipe("mediapipe"); + + companion object { + fun parse(value: String?) = entries.find { it.key == value } ?: ONNX + } + } + companion object { - val CUSTOM = LocalAiModel( + val CustomOnnx = LocalAiModel( id = "CUSTOM", + type = Type.ONNX, + name = "Custom", + size = "NaN", + sources = emptyList(), + ) + + val CustomMediaPipe = LocalAiModel( + id = "CUSTOM_MP", + type = Type.MediaPipe, name = "Custom", size = "NaN", sources = emptyList(), diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/feature/mediapipe/MediaPipe.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/feature/mediapipe/MediaPipe.kt index f2275b25..d27e0502 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/feature/mediapipe/MediaPipe.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/feature/mediapipe/MediaPipe.kt @@ -6,5 +6,4 @@ import io.reactivex.rxjava3.core.Single interface MediaPipe { fun process(payload: TextToImagePayload): Single - } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/preference/PreferenceManager.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/preference/PreferenceManager.kt index 1ee8950b..42c32959 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/preference/PreferenceManager.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/preference/PreferenceManager.kt @@ -12,9 +12,10 @@ interface PreferenceManager { var swarmUiModel: String var demoMode: Boolean var developerMode: Boolean - var localDiffusionCustomModelPath: String - var localDiffusionAllowCancel: Boolean - var localDiffusionSchedulerThread: SchedulersToken + var localMediaPipeCustomModelPath: String + var localOnnxCustomModelPath: String + var localOnnxAllowCancel: Boolean + var localOnnxSchedulerThread: SchedulersToken var monitorConnectivity: Boolean var autoSaveAiResults: Boolean var saveToMediaStore: Boolean @@ -30,8 +31,9 @@ interface PreferenceManager { var stabilityAiEngineId: String var onBoardingComplete: Boolean var forceSetupAfterUpdate: Boolean - var localModelId: String - var localUseNNAPI: Boolean + var localOnnxModelId: String + var localOnnxUseNNAPI: Boolean + var localMediaPipeModelId: String var designUseSystemColorPalette: Boolean var designUseSystemDarkTheme: Boolean var designDarkTheme: Boolean diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/DownloadableModelRepository.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/DownloadableModelRepository.kt index dcab5955..79efed66 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/DownloadableModelRepository.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/DownloadableModelRepository.kt @@ -8,11 +8,9 @@ import io.reactivex.rxjava3.core.Observable import io.reactivex.rxjava3.core.Single interface DownloadableModelRepository { - fun isModelDownloaded(id: String): Single fun download(id: String): Observable fun delete(id: String): Completable - fun getAll(): Single> - fun getById(id: String): Single - fun observeAll(): Flowable> - fun select(id: String): Completable + fun getAllOnnx(): Single> + fun getAllMediaPipe(): Single> + fun observeAllOnnx(): Flowable> } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalMediaPipeModelsUseCase.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalMediaPipeModelsUseCase.kt new file mode 100644 index 00000000..cd93801f --- /dev/null +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalMediaPipeModelsUseCase.kt @@ -0,0 +1,8 @@ +package com.shifthackz.aisdv1.domain.usecase.downloadable + +import com.shifthackz.aisdv1.domain.entity.LocalAiModel +import io.reactivex.rxjava3.core.Single + +interface GetLocalMediaPipeModelsUseCase { + operator fun invoke(): Single> +} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalMediaPipeModelsUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalMediaPipeModelsUseCaseImpl.kt new file mode 100644 index 00000000..4da8eeba --- /dev/null +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalMediaPipeModelsUseCaseImpl.kt @@ -0,0 +1,10 @@ +package com.shifthackz.aisdv1.domain.usecase.downloadable + +import com.shifthackz.aisdv1.domain.repository.DownloadableModelRepository + +internal class GetLocalMediaPipeModelsUseCaseImpl( + private val downloadableModelRepository: DownloadableModelRepository, + ) : GetLocalMediaPipeModelsUseCase { + + override fun invoke() = downloadableModelRepository.getAllMediaPipe() +} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalAiModelsUseCase.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalOnnxModelsUseCase.kt similarity index 84% rename from domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalAiModelsUseCase.kt rename to domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalOnnxModelsUseCase.kt index efe71374..f2cce297 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalAiModelsUseCase.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalOnnxModelsUseCase.kt @@ -3,6 +3,6 @@ package com.shifthackz.aisdv1.domain.usecase.downloadable import com.shifthackz.aisdv1.domain.entity.LocalAiModel import io.reactivex.rxjava3.core.Single -interface GetLocalAiModelsUseCase { +interface GetLocalOnnxModelsUseCase { operator fun invoke(): Single> } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalAiModelsUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalOnnxModelsUseCaseImpl.kt similarity index 59% rename from domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalAiModelsUseCaseImpl.kt rename to domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalOnnxModelsUseCaseImpl.kt index 7bdbe7e7..444eccca 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalAiModelsUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalOnnxModelsUseCaseImpl.kt @@ -2,9 +2,9 @@ package com.shifthackz.aisdv1.domain.usecase.downloadable import com.shifthackz.aisdv1.domain.repository.DownloadableModelRepository -internal class GetLocalAiModelsUseCaseImpl( +internal class GetLocalOnnxModelsUseCaseImpl( private val downloadableModelRepository: DownloadableModelRepository, -) : GetLocalAiModelsUseCase { +) : GetLocalOnnxModelsUseCase { - override fun invoke() = downloadableModelRepository.getAll() + override fun invoke() = downloadableModelRepository.getAllOnnx() } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalAiModelsUseCase.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalOnnxModelsUseCase.kt similarity index 83% rename from domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalAiModelsUseCase.kt rename to domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalOnnxModelsUseCase.kt index f79d31b1..021ebd1e 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalAiModelsUseCase.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalOnnxModelsUseCase.kt @@ -3,6 +3,6 @@ package com.shifthackz.aisdv1.domain.usecase.downloadable import com.shifthackz.aisdv1.domain.entity.LocalAiModel import io.reactivex.rxjava3.core.Flowable -interface ObserveLocalAiModelsUseCase { +interface ObserveLocalOnnxModelsUseCase { operator fun invoke(): Flowable> } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalAiModelsUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalOnnxModelsUseCaseImpl.kt similarity index 70% rename from domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalAiModelsUseCaseImpl.kt rename to domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalOnnxModelsUseCaseImpl.kt index 5fae54e1..e5e290f8 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalAiModelsUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalOnnxModelsUseCaseImpl.kt @@ -2,11 +2,11 @@ package com.shifthackz.aisdv1.domain.usecase.downloadable import com.shifthackz.aisdv1.domain.repository.DownloadableModelRepository -internal class ObserveLocalAiModelsUseCaseImpl( +internal class ObserveLocalOnnxModelsUseCaseImpl( private val repository: DownloadableModelRepository, -) : ObserveLocalAiModelsUseCase { +) : ObserveLocalOnnxModelsUseCase { override fun invoke() = repository - .observeAll() + .observeAllOnnx() .distinctUntilChanged() } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToLocalDiffusionUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToLocalDiffusionUseCaseImpl.kt index 5d89e78c..d2519425 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToLocalDiffusionUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToLocalDiffusionUseCaseImpl.kt @@ -12,7 +12,7 @@ internal class ConnectToLocalDiffusionUseCaseImpl( .map { originalConfiguration -> originalConfiguration.copy( source = ServerSource.LOCAL_MICROSOFT_ONNX, - localModelId = modelId, + localOnnxModelId = modelId, ) } .flatMapCompletable(setServerConfigurationUseCase::invoke) diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToMediaPipeUseCase.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToMediaPipeUseCase.kt index a1f04a74..5c1027e6 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToMediaPipeUseCase.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToMediaPipeUseCase.kt @@ -3,5 +3,5 @@ package com.shifthackz.aisdv1.domain.usecase.settings import io.reactivex.rxjava3.core.Single interface ConnectToMediaPipeUseCase { - operator fun invoke(): Single> + operator fun invoke(modelId: String): Single> } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToMediaPipeUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToMediaPipeUseCaseImpl.kt index cf2d2a3d..a60bdc68 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToMediaPipeUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToMediaPipeUseCaseImpl.kt @@ -8,14 +8,14 @@ internal class ConnectToMediaPipeUseCaseImpl( private val setServerConfigurationUseCase: SetServerConfigurationUseCase, ) : ConnectToMediaPipeUseCase { - override fun invoke() = getConfigurationUseCase() + override fun invoke(modelId: String): Single> = getConfigurationUseCase() .map { originalConfiguration -> originalConfiguration.copy( source = ServerSource.LOCAL_GOOGLE_MEDIA_PIPE, + localMediaPipeModelId = modelId, ) } .flatMapCompletable(setServerConfigurationUseCase::invoke) .andThen(Single.just(Result.success(Unit))) .onErrorResumeNext { t -> Single.just(Result.failure(t)) } - } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImpl.kt index 4070a0a6..e88ebd63 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImpl.kt @@ -24,8 +24,10 @@ internal class GetConfigurationUseCaseImpl( stabilityAiApiKey = preferenceManager.stabilityAiApiKey, stabilityAiEngineId = preferenceManager.stabilityAiEngineId, authCredentials = authorizationStore.getAuthorizationCredentials(), - localModelId = preferenceManager.localModelId, - localModelPath = preferenceManager.localDiffusionCustomModelPath, + localOnnxModelId = preferenceManager.localOnnxModelId, + localOnnxModelPath = preferenceManager.localOnnxCustomModelPath, + localMediaPipeModelId = preferenceManager.localMediaPipeModelId, + localMediaPipeModelPath = preferenceManager.localMediaPipeCustomModelPath, ) ) } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImpl.kt index 8d03619c..aab29abc 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImpl.kt @@ -24,7 +24,9 @@ internal class SetServerConfigurationUseCaseImpl( preferenceManager.huggingFaceModel = configuration.huggingFaceModel preferenceManager.stabilityAiApiKey = configuration.stabilityAiApiKey preferenceManager.stabilityAiEngineId = configuration.stabilityAiEngineId - preferenceManager.localModelId = configuration.localModelId - preferenceManager.localDiffusionCustomModelPath = configuration.localModelPath + preferenceManager.localOnnxModelId = configuration.localOnnxModelId + preferenceManager.localOnnxCustomModelPath = configuration.localOnnxModelPath + preferenceManager.localMediaPipeModelId = configuration.localMediaPipeModelId + preferenceManager.localMediaPipeCustomModelPath = configuration.localMediaPipeModelPath } } diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/ConfigurationMocks.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/ConfigurationMocks.kt index 73503575..da1d486a 100644 --- a/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/ConfigurationMocks.kt +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/ConfigurationMocks.kt @@ -12,6 +12,6 @@ val mockConfiguration = Configuration( huggingFaceModel = "5598", stabilityAiApiKey = "5598", stabilityAiEngineId = "5598", - localModelId = "5598", - localModelPath = "/storage/emulated/0/5598", + localOnnxModelId = "5598", + localOnnxModelPath = "/storage/emulated/0/5598", ) diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/LocalAiModelMocks.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/LocalAiModelMocks.kt index 69fcf3f0..d3c53d8a 100644 --- a/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/LocalAiModelMocks.kt +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/LocalAiModelMocks.kt @@ -3,7 +3,7 @@ package com.shifthackz.aisdv1.domain.mocks import com.shifthackz.aisdv1.domain.entity.LocalAiModel val mockLocalAiModels = listOf( - LocalAiModel.CUSTOM, + LocalAiModel.CustomOnnx, LocalAiModel( id = "1", name = "Model 1", diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalAiModelsUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalOnnxModelsUseCaseImplTest.kt similarity index 84% rename from domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalAiModelsUseCaseImplTest.kt rename to domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalOnnxModelsUseCaseImplTest.kt index 679ca297..d154c163 100644 --- a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalAiModelsUseCaseImplTest.kt +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalOnnxModelsUseCaseImplTest.kt @@ -7,15 +7,15 @@ import com.shifthackz.aisdv1.domain.repository.DownloadableModelRepository import io.reactivex.rxjava3.core.Single import org.junit.Test -class GetLocalAiModelsUseCaseImplTest { +class GetLocalOnnxModelsUseCaseImplTest { private val stubRepository = mock() - private val useCase = GetLocalAiModelsUseCaseImpl(stubRepository) + private val useCase = GetLocalOnnxModelsUseCaseImpl(stubRepository) @Test fun `given repository returned models list, expected valid models list value`() { - whenever(stubRepository.getAll()) + whenever(stubRepository.getAllOnnx()) .thenReturn(Single.just(mockLocalAiModels)) useCase() @@ -28,7 +28,7 @@ class GetLocalAiModelsUseCaseImplTest { @Test fun `given repository returned empty models list, expected empty models list value`() { - whenever(stubRepository.getAll()) + whenever(stubRepository.getAllOnnx()) .thenReturn(Single.just(emptyList())) useCase() @@ -43,7 +43,7 @@ class GetLocalAiModelsUseCaseImplTest { fun `given repository thrown exception, expected error value`() { val stubException = Throwable("Unable to collect local models.") - whenever(stubRepository.getAll()) + whenever(stubRepository.getAllOnnx()) .thenReturn(Single.error(stubException)) useCase() diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalAiModelsUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalOnnxModelsUseCaseImplTest.kt similarity index 91% rename from domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalAiModelsUseCaseImplTest.kt rename to domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalOnnxModelsUseCaseImplTest.kt index 00fde27a..a61684b1 100644 --- a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalAiModelsUseCaseImplTest.kt +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalOnnxModelsUseCaseImplTest.kt @@ -11,16 +11,16 @@ import io.reactivex.rxjava3.subjects.BehaviorSubject import org.junit.Before import org.junit.Test -class ObserveLocalAiModelsUseCaseImplTest { +class ObserveLocalOnnxModelsUseCaseImplTest { private val stubLocalModels = BehaviorSubject.create>() private val stubRepository = mock() - private val useCase = ObserveLocalAiModelsUseCaseImpl(stubRepository) + private val useCase = ObserveLocalOnnxModelsUseCaseImpl(stubRepository) @Before fun initialize() { - whenever(stubRepository.observeAll()) + whenever(stubRepository.observeAllOnnx()) .thenReturn(stubLocalModels.toFlowable(BackpressureStrategy.LATEST)) } @@ -68,7 +68,7 @@ class ObserveLocalAiModelsUseCaseImplTest { .assertNoErrors() .assertValueAt(0, mockLocalAiModels) - val changedLocalAiModels = listOf(LocalAiModel.CUSTOM) + val changedLocalAiModels = listOf(LocalAiModel.CustomOnnx) stubLocalModels.onNext(changedLocalAiModels) stubObserver @@ -97,7 +97,7 @@ class ObserveLocalAiModelsUseCaseImplTest { fun `given observer terminates with unexpected error, expected receive error value`() { val stubException = Throwable("Unexpected Flowable termination.") - whenever(stubRepository.observeAll()) + whenever(stubRepository.observeAllOnnx()) .thenReturn(Flowable.error(stubException)) useCase() diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImplTest.kt index 55290a49..c8bb1cd2 100644 --- a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImplTest.kt +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImplTest.kt @@ -69,12 +69,12 @@ class GetConfigurationUseCaseImplTest { } returns mockConfiguration.stabilityAiEngineId every { - stubPreferenceManager::localModelId.get() - } returns mockConfiguration.localModelId + stubPreferenceManager::localOnnxModelId.get() + } returns mockConfiguration.localOnnxModelId every { - stubPreferenceManager::localDiffusionCustomModelPath.get() - } returns mockConfiguration.localModelPath + stubPreferenceManager::localOnnxCustomModelPath.get() + } returns mockConfiguration.localOnnxModelPath useCase .invoke() diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImplTest.kt index 62a74c34..ed48a9d5 100644 --- a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImplTest.kt +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImplTest.kt @@ -68,11 +68,11 @@ class SetServerConfigurationUseCaseImplTest { } returns Unit every { - stubPreferenceManager::localModelId.set(any()) + stubPreferenceManager::localOnnxModelId.set(any()) } returns Unit every { - stubPreferenceManager::localDiffusionCustomModelPath.set(any()) + stubPreferenceManager::localOnnxCustomModelPath.set(any()) } returns Unit useCase diff --git a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/extensions/LocalDiffusionPaths.kt b/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/extensions/LocalDiffusionPaths.kt index 7cb6df87..1b22e6d9 100644 --- a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/extensions/LocalDiffusionPaths.kt +++ b/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/extensions/LocalDiffusionPaths.kt @@ -11,8 +11,8 @@ fun modelPathPrefix( localModelIdProvider: LocalModelIdProvider, ): String { val modelId = localModelIdProvider.get() - return if (modelId == LocalAiModel.CUSTOM.id) { - preferenceManager.localDiffusionCustomModelPath + return if (modelId == LocalAiModel.CustomOnnx.id) { + preferenceManager.localOnnxCustomModelPath } else { "${fileProviderDescriptor.localModelDirPath}/${modelId}" } diff --git a/feature/mediapipe/src/foss/java/com/shifthackz/aisdv1/feature/mediapipe/MediaPipeImpl.kt b/feature/mediapipe/src/foss/java/com/shifthackz/aisdv1/feature/mediapipe/MediaPipeImpl.kt index 0defdf78..3be83e5e 100644 --- a/feature/mediapipe/src/foss/java/com/shifthackz/aisdv1/feature/mediapipe/MediaPipeImpl.kt +++ b/feature/mediapipe/src/foss/java/com/shifthackz/aisdv1/feature/mediapipe/MediaPipeImpl.kt @@ -1,13 +1,13 @@ package com.shifthackz.aisdv1.feature.mediapipe -import android.content.Context import android.graphics.Bitmap -import com.shifthackz.aisdv1.core.common.file.FileProviderDescriptor import com.shifthackz.aisdv1.domain.entity.TextToImagePayload import com.shifthackz.aisdv1.domain.feature.mediapipe.MediaPipe import io.reactivex.rxjava3.core.Single internal class MediaPipeImpl : MediaPipe { - override fun process(payload: TextToImagePayload): Single = Single.error(Throwable("null")) + override fun process(payload: TextToImagePayload): Single { + return Single.error(IllegalStateException("Google AI MediaPipe is not supported on FOSS build.")) + } } diff --git a/feature/mediapipe/src/full/java/com/shifthackz/aisdv1/feature/mediapipe/MediaPipeImpl.kt b/feature/mediapipe/src/full/java/com/shifthackz/aisdv1/feature/mediapipe/MediaPipeImpl.kt index c75caad8..a6e7e98a 100644 --- a/feature/mediapipe/src/full/java/com/shifthackz/aisdv1/feature/mediapipe/MediaPipeImpl.kt +++ b/feature/mediapipe/src/full/java/com/shifthackz/aisdv1/feature/mediapipe/MediaPipeImpl.kt @@ -6,29 +6,33 @@ import com.google.mediapipe.framework.image.BitmapExtractor import com.google.mediapipe.tasks.vision.imagegenerator.ImageGenerator import com.google.mediapipe.tasks.vision.imagegenerator.ImageGenerator.ImageGeneratorOptions import com.shifthackz.aisdv1.core.common.file.FileProviderDescriptor +import com.shifthackz.aisdv1.core.common.log.debugLog import com.shifthackz.aisdv1.domain.entity.TextToImagePayload import com.shifthackz.aisdv1.domain.feature.mediapipe.MediaPipe +import com.shifthackz.aisdv1.domain.preference.PreferenceManager +import com.shifthackz.aisdv1.feature.mediapipe.extensions.modelPath import io.reactivex.rxjava3.core.Single internal class MediaPipeImpl( private val context: Context, + private val preferenceManager: PreferenceManager, private val fileProviderDescriptor: FileProviderDescriptor, -): MediaPipe { +) : MediaPipe { private var imageGenerator: ImageGenerator? = null override fun process(payload: TextToImagePayload): Single = Single.create { emitter -> try { initialize() - println("Generating...") + debugLog("Generating...") val result = imageGenerator?.generate( payload.prompt, payload.samplingSteps, payload.seed.toIntOrNull() ?: 0, ) - println("Extracting bitmap...") + debugLog("Extracting bitmap...") val bitmap = BitmapExtractor.extract(result?.generatedImage()) - println("bitmap = $bitmap, ${bitmap.width}X${bitmap.height}") + debugLog("bitmap = $bitmap, ${bitmap.width}X${bitmap.height}") close() if (!emitter.isDisposed) emitter.onSuccess(bitmap) } catch (e: Exception) { @@ -38,20 +42,22 @@ internal class MediaPipeImpl( } private fun initialize(): ImageGenerator { + val path = modelPath(preferenceManager, fileProviderDescriptor) + val options = ImageGeneratorOptions.builder() - .setImageGeneratorModelDirectory(fileProviderDescriptor.mediaPipeDirPath) + .setImageGeneratorModelDirectory(path) .build() val generator = ImageGenerator.createFromOptions(context, options) imageGenerator = generator - println("Initialized successfully! Path: ${fileProviderDescriptor.mediaPipeDirPath}") + debugLog("Initialized successfully! Path: $path") return generator } private fun close() = runCatching { - println("Closing...") + debugLog("Closing...") imageGenerator?.close() imageGenerator = null - println("Session closed!") + debugLog("Session closed!") } } diff --git a/feature/mediapipe/src/main/java/com/shifthackz/aisdv1/feature/mediapipe/extensions/MediaPipeModelPaths.kt b/feature/mediapipe/src/main/java/com/shifthackz/aisdv1/feature/mediapipe/extensions/MediaPipeModelPaths.kt new file mode 100644 index 00000000..91e49404 --- /dev/null +++ b/feature/mediapipe/src/main/java/com/shifthackz/aisdv1/feature/mediapipe/extensions/MediaPipeModelPaths.kt @@ -0,0 +1,17 @@ +package com.shifthackz.aisdv1.feature.mediapipe.extensions + +import com.shifthackz.aisdv1.core.common.file.FileProviderDescriptor +import com.shifthackz.aisdv1.domain.entity.LocalAiModel +import com.shifthackz.aisdv1.domain.preference.PreferenceManager + +fun modelPath( + preferenceManager: PreferenceManager, + fileProviderDescriptor: FileProviderDescriptor, +): String { + val modelId = preferenceManager.localMediaPipeModelId + return if (modelId == LocalAiModel.CustomMediaPipe.id) { + preferenceManager.localMediaPipeCustomModelPath + } else { + "${fileProviderDescriptor.localModelDirPath}/${modelId}" + } +} diff --git a/feature/mediapipe/src/playstore/java/com/shifthackz/aisdv1/feature/mediapipe/MediaPipeImpl.kt b/feature/mediapipe/src/playstore/java/com/shifthackz/aisdv1/feature/mediapipe/MediaPipeImpl.kt new file mode 100644 index 00000000..a6e7e98a --- /dev/null +++ b/feature/mediapipe/src/playstore/java/com/shifthackz/aisdv1/feature/mediapipe/MediaPipeImpl.kt @@ -0,0 +1,63 @@ +package com.shifthackz.aisdv1.feature.mediapipe + +import android.content.Context +import android.graphics.Bitmap +import com.google.mediapipe.framework.image.BitmapExtractor +import com.google.mediapipe.tasks.vision.imagegenerator.ImageGenerator +import com.google.mediapipe.tasks.vision.imagegenerator.ImageGenerator.ImageGeneratorOptions +import com.shifthackz.aisdv1.core.common.file.FileProviderDescriptor +import com.shifthackz.aisdv1.core.common.log.debugLog +import com.shifthackz.aisdv1.domain.entity.TextToImagePayload +import com.shifthackz.aisdv1.domain.feature.mediapipe.MediaPipe +import com.shifthackz.aisdv1.domain.preference.PreferenceManager +import com.shifthackz.aisdv1.feature.mediapipe.extensions.modelPath +import io.reactivex.rxjava3.core.Single + +internal class MediaPipeImpl( + private val context: Context, + private val preferenceManager: PreferenceManager, + private val fileProviderDescriptor: FileProviderDescriptor, +) : MediaPipe { + + private var imageGenerator: ImageGenerator? = null + + override fun process(payload: TextToImagePayload): Single = Single.create { emitter -> + try { + initialize() + debugLog("Generating...") + val result = imageGenerator?.generate( + payload.prompt, + payload.samplingSteps, + payload.seed.toIntOrNull() ?: 0, + ) + debugLog("Extracting bitmap...") + val bitmap = BitmapExtractor.extract(result?.generatedImage()) + debugLog("bitmap = $bitmap, ${bitmap.width}X${bitmap.height}") + close() + if (!emitter.isDisposed) emitter.onSuccess(bitmap) + } catch (e: Exception) { + close() + if (!emitter.isDisposed) emitter.onError(e) + } + } + + private fun initialize(): ImageGenerator { + val path = modelPath(preferenceManager, fileProviderDescriptor) + + val options = ImageGeneratorOptions.builder() + .setImageGeneratorModelDirectory(path) + .build() + + val generator = ImageGenerator.createFromOptions(context, options) + imageGenerator = generator + debugLog("Initialized successfully! Path: $path") + return generator + } + + private fun close() = runCatching { + debugLog("Closing...") + imageGenerator?.close() + imageGenerator = null + debugLog("Session closed!") + } +} diff --git a/feature/work/src/main/java/com/shifthackz/aisdv1/work/core/CoreGenerationWorker.kt b/feature/work/src/main/java/com/shifthackz/aisdv1/work/core/CoreGenerationWorker.kt index ea7504b4..b77330ee 100644 --- a/feature/work/src/main/java/com/shifthackz/aisdv1/work/core/CoreGenerationWorker.kt +++ b/feature/work/src/main/java/com/shifthackz/aisdv1/work/core/CoreGenerationWorker.kt @@ -93,7 +93,7 @@ internal abstract class CoreGenerationWorker( body = subTitle, silent = true, progress = status.current to status.total, - canCancel = preferenceManager.localDiffusionAllowCancel, + canCancel = preferenceManager.localOnnxAllowCancel, ) } } diff --git a/network/src/main/java/com/shifthackz/aisdv1/network/api/sdai/DownloadableModelsApi.kt b/network/src/main/java/com/shifthackz/aisdv1/network/api/sdai/DownloadableModelsApi.kt index 90adc17b..78a35382 100644 --- a/network/src/main/java/com/shifthackz/aisdv1/network/api/sdai/DownloadableModelsApi.kt +++ b/network/src/main/java/com/shifthackz/aisdv1/network/api/sdai/DownloadableModelsApi.kt @@ -11,7 +11,9 @@ import java.io.File interface DownloadableModelsApi { - fun fetchDownloadableModels(): Single> + fun fetchOnnxModels(): Single> + + fun fetchMediaPipeModels(): Single> fun downloadModel( remoteUrl: String, @@ -23,7 +25,10 @@ interface DownloadableModelsApi { interface RawApi { @GET("/models.json") - fun fetchDownloadableModels(): Single> + fun fetchOnnxModels(): Single> + + @GET("/mediapipe.json") + fun fetchMediaPipeModels(): Single> @Streaming @GET diff --git a/network/src/main/java/com/shifthackz/aisdv1/network/api/sdai/DownloadableModelsApiImpl.kt b/network/src/main/java/com/shifthackz/aisdv1/network/api/sdai/DownloadableModelsApiImpl.kt index 88ddc5ef..04fd71c1 100644 --- a/network/src/main/java/com/shifthackz/aisdv1/network/api/sdai/DownloadableModelsApiImpl.kt +++ b/network/src/main/java/com/shifthackz/aisdv1/network/api/sdai/DownloadableModelsApiImpl.kt @@ -1,6 +1,7 @@ package com.shifthackz.aisdv1.network.api.sdai import com.shifthackz.aisdv1.network.extensions.saveFile +import com.shifthackz.aisdv1.network.response.DownloadableModelResponse import io.reactivex.rxjava3.core.Observable import io.reactivex.rxjava3.core.Single import java.io.File @@ -9,7 +10,9 @@ internal class DownloadableModelsApiImpl( private val rawApi: DownloadableModelsApi.RawApi, ) : DownloadableModelsApi { - override fun fetchDownloadableModels() = rawApi.fetchDownloadableModels() + override fun fetchOnnxModels() = rawApi.fetchOnnxModels() + + override fun fetchMediaPipeModels() = rawApi.fetchMediaPipeModels() override fun downloadModel( remoteUrl: String, diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/di/ViewModelModule.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/di/ViewModelModule.kt index a6060cb9..a2dfe010 100755 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/di/ViewModelModule.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/di/ViewModelModule.kt @@ -72,7 +72,8 @@ val viewModelModule = module { launchSource = launchSource, dispatchersProvider = get(), getConfigurationUseCase = get(), - getLocalAiModelsUseCase = get(), + getLocalOnnxModelsUseCase = get(), + getLocalMediaPipeModelsUseCase = get(), fetchAndGetHuggingFaceModelsUseCase = get(), urlValidator = get(), stringValidator = get(), diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/debug/DebugMenuViewModel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/debug/DebugMenuViewModel.kt index 0d2fd5df..c6b5f7de 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/debug/DebugMenuViewModel.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/debug/DebugMenuViewModel.kt @@ -64,7 +64,7 @@ class DebugMenuViewModel( DebugMenuIntent.ViewLogs -> mainRouter.navigateToLogger() DebugMenuIntent.AllowLocalDiffusionCancel -> { - preferenceManager.localDiffusionAllowCancel = !currentState.localDiffusionAllowCancel + preferenceManager.localOnnxAllowCancel = !currentState.localDiffusionAllowCancel } DebugMenuIntent.LocalDiffusionScheduler.Request -> updateState { @@ -72,7 +72,7 @@ class DebugMenuViewModel( } is DebugMenuIntent.LocalDiffusionScheduler.Confirm -> { - preferenceManager.localDiffusionSchedulerThread = intent.token + preferenceManager.localOnnxSchedulerThread = intent.token } DebugMenuIntent.DismissModal -> updateState { diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsViewModel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsViewModel.kt index a669acc0..66a2c97a 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsViewModel.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsViewModel.kt @@ -142,7 +142,7 @@ class SettingsViewModel( } is SettingsIntent.UpdateFlag.NNAPI -> { - preferenceManager.localUseNNAPI = intent.flag + preferenceManager.localOnnxUseNNAPI = intent.flag } is SettingsIntent.UpdateFlag.TaggedInput -> { diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupScreen.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupScreen.kt index 5635568b..f4e5b0f2 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupScreen.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupScreen.kt @@ -153,7 +153,11 @@ fun ServerSetupScreenContent( onClick = { processIntent(ServerSetupIntent.MainButtonClick) }, enabled = when (state.step) { ServerSetupState.Step.CONFIGURE -> when (state.mode) { - ServerSource.LOCAL_MICROSOFT_ONNX -> state.localModels.any { + ServerSource.LOCAL_MICROSOFT_ONNX -> state.localOnnxModels.any { + it.downloaded && it.selected + } + + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> state.localMediaPipeModels.any { it.downloaded && it.selected } diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupState.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupState.kt index 553e7bde..639b3c63 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupState.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupState.kt @@ -33,9 +33,12 @@ data class ServerSetupState( val password: String = "", val huggingFaceModels: List = emptyList(), val huggingFaceModel: String = "", - val localModels: List = emptyList(), - val localCustomModel: Boolean = false, - val localCustomModelPath: String = "", + val localOnnxModels: List = emptyList(), + val localOnnxCustomModel: Boolean = false, + val localOnnxCustomModelPath: String = "", + val localMediaPipeModels: List = emptyList(), + val localMediaPipeCustomModel: Boolean = false, + val localMediaPipeCustomModelPath: String = "", val passwordVisible: Boolean = false, val serverUrlValidationError: UiText? = null, val swarmUiUrlValidationError: UiText? = null, @@ -45,9 +48,31 @@ data class ServerSetupState( val huggingFaceApiKeyValidationError: UiText? = null, val openAiApiKeyValidationError: UiText? = null, val stabilityAiApiKeyValidationError: UiText? = null, - val localCustomModelPathValidationError: UiText? = null, + val localCustomOnnxPathValidationError: UiText? = null, + val localCustomMediaPipePathValidationError: UiText? = null, ) : MviState, KoinComponent { + val localCustomModel: Boolean + get() = if (mode == ServerSource.LOCAL_MICROSOFT_ONNX) { + localOnnxCustomModel + } else { + localMediaPipeCustomModel + } + + val localModels: List + get() = if (mode == ServerSource.LOCAL_MICROSOFT_ONNX) { + localOnnxModels + } else { + localMediaPipeModels + } + + val localCustomModelPathValidationError: UiText? + get() = if (mode == ServerSource.LOCAL_MICROSOFT_ONNX) { + localCustomOnnxPathValidationError + } else { + localCustomMediaPipePathValidationError + } + val demoModeUrl: String get() { val linksProvider: LinksProvider by inject() diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt index 1aa89611..68f3c47b 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt @@ -4,6 +4,7 @@ import com.shifthackz.aisdv1.core.common.appbuild.BuildInfoProvider import com.shifthackz.aisdv1.core.common.appbuild.BuildType import com.shifthackz.aisdv1.core.common.log.errorLog import com.shifthackz.aisdv1.core.common.schedulers.DispatchersProvider +import com.shifthackz.aisdv1.core.common.model.Quadruple import com.shifthackz.aisdv1.core.common.schedulers.SchedulersProvider import com.shifthackz.aisdv1.core.common.schedulers.subscribeOnMainThread import com.shifthackz.aisdv1.core.model.asUiText @@ -21,14 +22,16 @@ import com.shifthackz.aisdv1.domain.interactor.wakelock.WakeLockInterActor import com.shifthackz.aisdv1.domain.preference.PreferenceManager import com.shifthackz.aisdv1.domain.usecase.downloadable.DeleteModelUseCase import com.shifthackz.aisdv1.domain.usecase.downloadable.DownloadModelUseCase -import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalAiModelsUseCase +import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalMediaPipeModelsUseCase +import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalOnnxModelsUseCase import com.shifthackz.aisdv1.domain.usecase.huggingface.FetchAndGetHuggingFaceModelsUseCase import com.shifthackz.aisdv1.domain.usecase.settings.GetConfigurationUseCase import com.shifthackz.aisdv1.presentation.model.LaunchSource import com.shifthackz.aisdv1.presentation.model.Modal import com.shifthackz.aisdv1.presentation.navigation.router.main.MainRouter import com.shifthackz.aisdv1.presentation.screen.setup.mappers.allowedModes -import com.shifthackz.aisdv1.presentation.screen.setup.mappers.mapLocalCustomModelSwitchState +import com.shifthackz.aisdv1.presentation.screen.setup.mappers.mapLocalCustomMediaPipeSwitchState +import com.shifthackz.aisdv1.presentation.screen.setup.mappers.mapLocalCustomOnnxSwitchState import com.shifthackz.aisdv1.presentation.screen.setup.mappers.mapToUi import com.shifthackz.aisdv1.presentation.screen.setup.mappers.withNewState import com.shifthackz.aisdv1.presentation.utils.Constants @@ -40,7 +43,8 @@ class ServerSetupViewModel( launchSource: LaunchSource, dispatchersProvider: DispatchersProvider, getConfigurationUseCase: GetConfigurationUseCase, - getLocalAiModelsUseCase: GetLocalAiModelsUseCase, + getLocalOnnxModelsUseCase: GetLocalOnnxModelsUseCase, + getLocalMediaPipeModelsUseCase: GetLocalMediaPipeModelsUseCase, fetchAndGetHuggingFaceModelsUseCase: FetchAndGetHuggingFaceModelsUseCase, private val urlValidator: UrlValidator, private val stringValidator: CommonStringValidator, @@ -78,12 +82,13 @@ class ServerSetupViewModel( init { !Single.zip( getConfigurationUseCase(), - getLocalAiModelsUseCase(), + getLocalOnnxModelsUseCase(), + getLocalMediaPipeModelsUseCase(), fetchAndGetHuggingFaceModelsUseCase(), - ::Triple, + ::Quadruple, ) .subscribeOnMainThread(schedulersProvider) - .subscribeBy(::errorLog) { (configuration, localModels, hfModels) -> + .subscribeBy(::errorLog) { (configuration, onnxModels, mpModels, hfModels) -> updateState { state -> state.copy( huggingFaceModels = hfModels.map(HuggingFaceModel::alias), @@ -91,9 +96,12 @@ class ServerSetupViewModel( huggingFaceApiKey = configuration.huggingFaceApiKey, openAiApiKey = configuration.openAiApiKey, stabilityAiApiKey = configuration.stabilityAiApiKey, - localModels = localModels.mapToUi(), - localCustomModel = localModels.mapLocalCustomModelSwitchState(), - localCustomModelPath = configuration.localModelPath, + localOnnxModels = onnxModels.mapToUi(), + localOnnxCustomModel = onnxModels.mapLocalCustomOnnxSwitchState(), + localOnnxCustomModelPath = configuration.localOnnxModelPath, + localMediaPipeModels = mpModels.mapToUi(), + localMediaPipeCustomModel = mpModels.mapLocalCustomMediaPipeSwitchState(), + localMediaPipeCustomModelPath = configuration.localMediaPipeModelPath, mode = configuration.source, allowedModes = buildInfoProvider.allowedModes, demoMode = configuration.demoMode, @@ -115,15 +123,28 @@ class ServerSetupViewModel( } override fun processIntent(intent: ServerSetupIntent) = when (intent) { - is ServerSetupIntent.AllowLocalCustomModel -> updateState { - it.copy( - localCustomModel = intent.allow, - localModels = currentState.localModels.withNewState( - currentState.localModels.find { m -> m.id == LocalAiModel.CUSTOM.id }?.copy( - selected = intent.allow, + is ServerSetupIntent.AllowLocalCustomModel -> updateState { state -> + when (state.mode) { + ServerSource.LOCAL_MICROSOFT_ONNX -> state.copy( + localOnnxCustomModel = intent.allow, + localOnnxModels = state.localOnnxModels.withNewState( + state.localOnnxModels.find { m -> m.id == LocalAiModel.CustomOnnx.id }?.copy( + selected = intent.allow, + ), ), - ), - ) + ) + + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> state.copy( + localMediaPipeCustomModel = intent.allow, + localMediaPipeModels = state.localMediaPipeModels.withNewState( + state.localMediaPipeModels.find { m -> m.id == LocalAiModel.CustomMediaPipe.id}?.copy( + selected = intent.allow + ) + ) + ) + + else -> state + } } ServerSetupIntent.DismissDialog -> setScreenModal(Modal.None) @@ -136,7 +157,7 @@ class ServerSetupViewModel( .subscribeBy(::errorLog) it.copy( screenModal = Modal.None, - localModels = currentState.localModels.withNewState( + localOnnxModels = currentState.localOnnxModels.withNewState( intent.model.copy( downloadState = DownloadState.Unknown, downloaded = false, @@ -146,15 +167,22 @@ class ServerSetupViewModel( } is ServerSetupIntent.SelectLocalModel -> { - if (currentState.localModels.any { it.downloadState is DownloadState.Downloading }) { - Unit - } - updateState { - it.copy( - localModels = currentState.localModels.withNewState( - intent.model.copy(selected = true), - ), - ) + updateState { state -> + when (state.mode) { + ServerSource.LOCAL_MICROSOFT_ONNX -> state.copy( + localOnnxModels = state.localOnnxModels.withNewState( + intent.model.copy(selected = true), + ), + ) + + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> state.copy( + localMediaPipeModels = state.localMediaPipeModels.withNewState( + intent.model.copy(selected = true), + ), + ) + + else -> state + } } } @@ -241,11 +269,20 @@ class ServerSetupViewModel( ServerSetupIntent.ConnectToLocalHost -> connectToServer() - is ServerSetupIntent.SelectLocalModelPath -> updateState { - it.copy( - localCustomModelPath = intent.value, - localCustomModelPathValidationError = null, - ) + is ServerSetupIntent.SelectLocalModelPath -> updateState { state -> + when (state.mode) { + ServerSource.LOCAL_MICROSOFT_ONNX -> state.copy( + localOnnxCustomModelPath = intent.value, + localCustomOnnxPathValidationError = null, + ) + + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> state.copy( + localMediaPipeCustomModelPath = intent.value, + localCustomMediaPipePathValidationError = null, + ) + + else -> state + } } } @@ -298,22 +335,28 @@ class ServerSetupViewModel( } } - ServerSource.LOCAL_MICROSOFT_ONNX -> { - if (currentState.localCustomModel) { - val validation = filePathValidator(currentState.localCustomModelPath) + ServerSource.LOCAL_MICROSOFT_ONNX -> if (currentState.localOnnxCustomModel) { + val validation = filePathValidator(currentState.localOnnxCustomModelPath) + updateState { + it.copy(localCustomOnnxPathValidationError = validation.mapToUi()) + } + validation.isValid + } else { + currentState.localOnnxModels.find { it.selected && it.downloaded } != null + } + + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> when { + buildInfoProvider.type == BuildType.FOSS -> false + currentState.localMediaPipeCustomModel -> { + val validation = filePathValidator(currentState.localMediaPipeCustomModelPath) updateState { - it.copy(localCustomModelPathValidationError = validation.mapToUi()) + it.copy(localCustomMediaPipePathValidationError = validation.mapToUi()) } validation.isValid - } else { - currentState.localModels.find { it.selected && it.downloaded } != null } - } - - ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> if (buildInfoProvider.type == BuildType.FOSS) { - false - } else { - true + else -> { + currentState.localMediaPipeModels.find { it.selected && it.downloaded } != null + } } ServerSource.HUGGING_FACE -> { @@ -417,13 +460,15 @@ class ServerSetupViewModel( } private fun connectToLocalDiffusion(): Single> { - preferenceManager.localDiffusionCustomModelPath = currentState.localCustomModelPath - val localModelId = currentState.localModels.find { it.selected }?.id ?: "" + preferenceManager.localOnnxCustomModelPath = currentState.localOnnxCustomModelPath + val localModelId = currentState.localOnnxModels.find { it.selected }?.id ?: "" return setupConnectionInterActor.connectToLocal(localModelId) } private fun connectToMediaPipe(): Single> { - return setupConnectionInterActor.connectToMediaPipe() + preferenceManager.localMediaPipeCustomModelPath = currentState.localMediaPipeCustomModelPath + val localModelId = currentState.localMediaPipeModels.find { it.selected }?.id ?: "" + return setupConnectionInterActor.connectToMediaPipe(localModelId) } private fun localModelDownloadClickReducer(localModel: ServerSetupState.LocalModel) { @@ -438,12 +483,26 @@ class ServerSetupViewModel( !deleteModelUseCase(localModel.id) .subscribeOnMainThread(schedulersProvider) .subscribeBy(::errorLog) - updateState { - it.copy( - localModels = currentState.localModels.withNewState( - localModel.copy(downloadState = DownloadState.Unknown), - ), - ) + updateState { state -> + when (state.mode) { + ServerSource.LOCAL_MICROSOFT_ONNX -> { + state.copy( + localOnnxModels = state.localOnnxModels.withNewState( + localModel.copy(downloadState = DownloadState.Unknown), + ), + ) + } + + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> { + state.copy( + localMediaPipeModels = state.localMediaPipeModels.withNewState( + localModel.copy(downloadState = DownloadState.Unknown), + ), + ) + } + + else -> state + } } } // User deletes local model @@ -452,14 +511,30 @@ class ServerSetupViewModel( } // User requested new download operation else -> { - updateState { - it.copy( - localModels = currentState.localModels.withNewState( - localModel.copy( - downloadState = DownloadState.Downloading(), - ), - ), - ) + updateState { state -> + when (state.mode) { + ServerSource.LOCAL_MICROSOFT_ONNX -> { + state.copy( + localOnnxModels = state.localOnnxModels.withNewState( + localModel.copy( + downloadState = DownloadState.Downloading(), + ), + ), + ) + } + + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> { + state.copy( + localMediaPipeModels = state.localMediaPipeModels.withNewState( + localModel.copy( + downloadState = DownloadState.Downloading(), + ), + ), + ) + } + + else -> state + } } !downloadModelUseCase(localModel.id) .distinctUntilChanged() @@ -469,13 +544,18 @@ class ServerSetupViewModel( onError = { t -> errorLog(t) val message = t.localizedMessage ?: "Error" - updateState { - it.copy( - localModels = currentState.localModels.withNewState( + updateState { state -> + state.copy( + localOnnxModels = state.localOnnxModels.withNewState( localModel.copy( downloadState = DownloadState.Error(t), ), ), + localMediaPipeModels = state.localMediaPipeModels.withNewState( + localModel.copy( + downloadState = DownloadState.Error(t), + ), + ) ) } setScreenModal(Modal.Error(message.asUiText())) @@ -484,7 +564,13 @@ class ServerSetupViewModel( updateState { when (downloadState) { is DownloadState.Complete -> it.copy( - localModels = it.localModels.withNewState( + localOnnxModels = it.localOnnxModels.withNewState( + localModel.copy( + downloadState = downloadState, + downloaded = true, + ), + ), + localMediaPipeModels = it.localMediaPipeModels.withNewState( localModel.copy( downloadState = downloadState, downloaded = true, @@ -493,7 +579,10 @@ class ServerSetupViewModel( ) else -> it.copy( - localModels = it.localModels.withNewState( + localOnnxModels = it.localOnnxModels.withNewState( + localModel.copy(downloadState = downloadState), + ), + localMediaPipeModels = it.localMediaPipeModels.withNewState( localModel.copy(downloadState = downloadState), ), ) diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/LocalDiffusionForm.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/LocalDiffusionForm.kt index 6ebfed22..7a8b65ce 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/LocalDiffusionForm.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/LocalDiffusionForm.kt @@ -54,6 +54,7 @@ import com.shifthackz.aisdv1.core.extensions.getRealPath import com.shifthackz.aisdv1.core.model.asString import com.shifthackz.aisdv1.domain.entity.DownloadState import com.shifthackz.aisdv1.domain.entity.LocalAiModel +import com.shifthackz.aisdv1.domain.entity.ServerSource import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupIntent import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupScreenTags.CUSTOM_MODEL_SWITCH import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupState @@ -91,7 +92,7 @@ fun LocalDiffusionForm( val icon = when (model.downloadState) { is DownloadState.Downloading -> Icons.Outlined.FileDownload else -> when { - model.id == LocalAiModel.CUSTOM.id -> Icons.Outlined.Landslide + model.id == LocalAiModel.CustomOnnx.id -> Icons.Outlined.Landslide model.downloaded -> Icons.Outlined.FileDownloadDone else -> Icons.Outlined.FileDownloadOff } @@ -113,14 +114,17 @@ fun LocalDiffusionForm( overflow = TextOverflow.Ellipsis, maxLines = 2 ) - if (model.id != LocalAiModel.CUSTOM.id) { + if (model.id != LocalAiModel.CustomOnnx.id) { Text( text = model.size, maxLines = 1 ) } } - if (model.id != LocalAiModel.CUSTOM.id) { + // Do not display action button for custom model + if (model.id != LocalAiModel.CustomOnnx.id + && model.id != LocalAiModel.CustomMediaPipe.id + ) { Button( modifier = Modifier.padding(end = 8.dp), onClick = { processIntent(ServerSetupIntent.LocalModel.ClickReduce(model)) }, @@ -142,7 +146,9 @@ fun LocalDiffusionForm( } } } - if (model.id == LocalAiModel.CUSTOM.id) { + if (model.id == LocalAiModel.CustomOnnx.id + || model.id == LocalAiModel.CustomMediaPipe.id + ) { Column( modifier = Modifier.padding(8.dp), ) { @@ -163,7 +169,7 @@ fun LocalDiffusionForm( val folderStyle = MaterialTheme.typography.bodySmall Text( modifier = Modifier.padding(start = 12.dp), - text = state.localCustomModelPath, + text = state.localOnnxCustomModelPath, style = folderStyle, ) @@ -258,17 +264,29 @@ fun LocalDiffusionForm( modifier = Modifier .fillMaxWidth() .padding(top = 32.dp, bottom = 8.dp), - text = stringResource(id = LocalizationR.string.hint_local_diffusion_title), + text = stringResource( + id = if (state.mode == ServerSource.LOCAL_MICROSOFT_ONNX) { + LocalizationR.string.hint_local_diffusion_title + } else { + LocalizationR.string.hint_mediapipe_title + }, + ), style = MaterialTheme.typography.bodyLarge, textAlign = TextAlign.Center, fontWeight = FontWeight.Bold, ) Text( modifier = Modifier.padding(top = 16.dp, bottom = 16.dp), - text = stringResource(id = LocalizationR.string.hint_local_diffusion_sub_title), + text = stringResource( + id = if (state.mode == ServerSource.LOCAL_MICROSOFT_ONNX) { + LocalizationR.string.hint_local_diffusion_sub_title + } else { + LocalizationR.string.hint_mediapipe_sub_title + }, + ), style = MaterialTheme.typography.bodyMedium, ) - if (buildInfoProvider.type == BuildType.FOSS) { + if (buildInfoProvider.type != BuildType.PLAY) { Row( verticalAlignment = Alignment.CenterVertically, ) { @@ -285,7 +303,7 @@ fun LocalDiffusionForm( ) } } - if (state.localCustomModel && buildInfoProvider.type == BuildType.FOSS) { + if (state.localCustomModel && buildInfoProvider.type != BuildType.PLAY) { Text( modifier = Modifier .align(Alignment.CenterHorizontally) @@ -341,7 +359,7 @@ fun LocalDiffusionForm( modifier = Modifier .fillMaxWidth() .padding(top = 14.dp), - value = state.localCustomModelPath, + value = state.localOnnxCustomModelPath, onValueChange = { processIntent(ServerSetupIntent.SelectLocalModelPath(it)) }, enabled = true, singleLine = true, @@ -387,7 +405,7 @@ fun LocalDiffusionForm( } state.localModels .filter { - val customPredicate = it.id == LocalAiModel.CUSTOM.id + val customPredicate = it.id == LocalAiModel.CustomOnnx.id || it.id == LocalAiModel.CustomMediaPipe.id if (state.localCustomModel) customPredicate else !customPredicate } .forEach { localModel -> modelItemUi(localModel) } diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/MediaPipeForm.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/MediaPipeForm.kt new file mode 100644 index 00000000..3bb26b1d --- /dev/null +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/MediaPipeForm.kt @@ -0,0 +1,29 @@ +package com.shifthackz.aisdv1.presentation.screen.setup.forms + +import androidx.compose.runtime.Composable +import androidx.compose.ui.Modifier +import com.shifthackz.aisdv1.core.common.appbuild.BuildInfoProvider +import com.shifthackz.aisdv1.core.common.appbuild.BuildType +import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupIntent +import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupState + +@Composable +fun MediaPipeForm( + modifier: Modifier = Modifier, + state: ServerSetupState, + buildInfoProvider: BuildInfoProvider = BuildInfoProvider.stub, + processIntent: (ServerSetupIntent) -> Unit = {}, +) { + when (buildInfoProvider.type) { + BuildType.FOSS -> { + + } + + else -> LocalDiffusionForm( + modifier = modifier, + state = state, + buildInfoProvider = buildInfoProvider, + processIntent = processIntent, + ) + } +} diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/mappers/LocalModelMappers.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/mappers/LocalModelMappers.kt index f3a9ddc6..3cd1e443 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/mappers/LocalModelMappers.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/mappers/LocalModelMappers.kt @@ -5,8 +5,11 @@ import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupState fun List.mapToUi(): List = map(LocalAiModel::mapToUi) -fun List.mapLocalCustomModelSwitchState(): Boolean = - find { it.selected && it.id == LocalAiModel.CUSTOM.id } != null +fun List.mapLocalCustomOnnxSwitchState(): Boolean = + find { it.selected && it.id == LocalAiModel.CustomOnnx.id } != null + +fun List.mapLocalCustomMediaPipeSwitchState(): Boolean = + find { it.selected && it.id == LocalAiModel.CustomMediaPipe.id } != null fun LocalAiModel.mapToUi(): ServerSetupState.LocalModel = with(this) { ServerSetupState.LocalModel( diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/steps/ConfigurationStep.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/steps/ConfigurationStep.kt index eb8f59d2..d25e8237 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/steps/ConfigurationStep.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/steps/ConfigurationStep.kt @@ -10,6 +10,7 @@ import com.shifthackz.aisdv1.presentation.screen.setup.forms.Automatic1111Form import com.shifthackz.aisdv1.presentation.screen.setup.forms.HordeForm import com.shifthackz.aisdv1.presentation.screen.setup.forms.HuggingFaceForm import com.shifthackz.aisdv1.presentation.screen.setup.forms.LocalDiffusionForm +import com.shifthackz.aisdv1.presentation.screen.setup.forms.MediaPipeForm import com.shifthackz.aisdv1.presentation.screen.setup.forms.OpenAiForm import com.shifthackz.aisdv1.presentation.screen.setup.forms.StabilityAiForm import com.shifthackz.aisdv1.presentation.screen.setup.forms.SwarmUiForm @@ -59,7 +60,11 @@ fun ConfigurationStep( processIntent = processIntent, ) - ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> Unit + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> MediaPipeForm( + state = state, + buildInfoProvider = buildInfoProvider, + processIntent = processIntent, + ) } } } diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageViewModel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageViewModel.kt index 5b5d96a3..fc30f692 100755 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageViewModel.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageViewModel.kt @@ -68,7 +68,7 @@ class TextToImageViewModel( private val progressModal: Modal get() { if (currentState.mode == ServerSource.LOCAL_MICROSOFT_ONNX) { - return Modal.Generating(canCancel = preferenceManager.localDiffusionAllowCancel) + return Modal.Generating(canCancel = preferenceManager.localOnnxAllowCancel) } return Modal.Communicating() } diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModel.kt index a1f0a0d6..9bb71002 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModel.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModel.kt @@ -11,7 +11,7 @@ import com.shifthackz.aisdv1.domain.entity.Configuration import com.shifthackz.aisdv1.domain.entity.LocalAiModel import com.shifthackz.aisdv1.domain.entity.ServerSource import com.shifthackz.aisdv1.domain.preference.PreferenceManager -import com.shifthackz.aisdv1.domain.usecase.downloadable.ObserveLocalAiModelsUseCase +import com.shifthackz.aisdv1.domain.usecase.downloadable.ObserveLocalOnnxModelsUseCase import com.shifthackz.aisdv1.domain.usecase.huggingface.FetchAndGetHuggingFaceModelsUseCase import com.shifthackz.aisdv1.domain.usecase.sdmodel.GetStableDiffusionModelsUseCase import com.shifthackz.aisdv1.domain.usecase.sdmodel.SelectStableDiffusionModelUseCase @@ -25,7 +25,7 @@ import io.reactivex.rxjava3.kotlin.subscribeBy class EngineSelectionViewModel( dispatchersProvider: DispatchersProvider, fetchAndGetSwarmUiModelsUseCase: FetchAndGetSwarmUiModelsUseCase, - observeLocalAiModelsUseCase: ObserveLocalAiModelsUseCase, + observeLocalOnnxModelsUseCase: ObserveLocalOnnxModelsUseCase, fetchAndGetStabilityAiEnginesUseCase: FetchAndGetStabilityAiEnginesUseCase, getHuggingFaceModelsUseCase: FetchAndGetHuggingFaceModelsUseCase, private val preferenceManager: PreferenceManager, @@ -61,8 +61,8 @@ class EngineSelectionViewModel( .onErrorReturn { emptyList() } .toFlowable() - val localAiModels = observeLocalAiModelsUseCase() - .map { models -> models.filter { it.downloaded || it.id == LocalAiModel.CUSTOM.id } } + val localAiModels = observeLocalOnnxModelsUseCase() + .map { models -> models.filter { it.downloaded || it.id == LocalAiModel.CustomOnnx.id } } .onErrorReturn { emptyList() } !Flowable.combineLatest( @@ -94,7 +94,7 @@ class EngineSelectionViewModel( stEngines = stEngines.map { it.id }, selectedStEngine = config.stabilityAiEngineId, localAiModels = localModels, - selectedLocalAiModelId = localModels.firstOrNull { it.id == config.localModelId }?.id + selectedLocalAiModelId = localModels.firstOrNull { it.id == config.localOnnxModelId }?.id ?: state.selectedLocalAiModelId ) } @@ -131,7 +131,7 @@ class EngineSelectionViewModel( ServerSource.STABILITY_AI -> preferenceManager.stabilityAiEngineId = intent.value - ServerSource.LOCAL_MICROSOFT_ONNX -> preferenceManager.localModelId = intent.value + ServerSource.LOCAL_MICROSOFT_ONNX -> preferenceManager.localOnnxModelId = intent.value else -> Unit } diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/LocalAiModelMocks.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/LocalAiModelMocks.kt index 8b9eb04a..cdf190e9 100644 --- a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/LocalAiModelMocks.kt +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/LocalAiModelMocks.kt @@ -5,7 +5,7 @@ import com.shifthackz.aisdv1.domain.entity.LocalAiModel import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupState val mockLocalAiModels = listOf( - LocalAiModel.CUSTOM, + LocalAiModel.CustomOnnx, LocalAiModel( id = "1", name = "Model 1", diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsViewModelTest.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsViewModelTest.kt index 05e9a78f..d9453b19 100644 --- a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsViewModelTest.kt +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsViewModelTest.kt @@ -237,13 +237,13 @@ class SettingsViewModelTest : CoreViewModelTest() { @Test fun `given received UpdateFlag NNAPI intent, expected localUseNNAPI preference updated`() { every { - stubPreferenceManager::localUseNNAPI.set(any()) + stubPreferenceManager::localOnnxUseNNAPI.set(any()) } returns Unit viewModel.processIntent(SettingsIntent.UpdateFlag.NNAPI(true)) verify { - stubPreferenceManager::localUseNNAPI.set(true) + stubPreferenceManager::localOnnxUseNNAPI.set(true) } } diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupScreenTest.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupScreenTest.kt index b2485734..27eb75b1 100644 --- a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupScreenTest.kt +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupScreenTest.kt @@ -126,7 +126,7 @@ class ServerSetupScreenTest : CoreComposeTest { it.copy( step = ServerSetupState.Step.CONFIGURE, mode = ServerSource.LOCAL_MICROSOFT_ONNX, - localModels = mockLocalAiModels.mapToUi() + localOnnxModels = mockLocalAiModels.mapToUi() ) } val setupButton = onNodeWithTestTag(ServerSetupScreenTags.MAIN_BUTTON) @@ -143,9 +143,9 @@ class ServerSetupScreenTest : CoreComposeTest { switch.performClick() stubUiState.update { it.copy( - localCustomModel = true, - localModels = it.localModels.withNewState( - it.localModels.find { m -> m.id == LocalAiModel.CUSTOM.id }!!.copy( + localOnnxCustomModel = true, + localOnnxModels = it.localOnnxModels.withNewState( + it.localOnnxModels.find { m -> m.id == LocalAiModel.CustomOnnx.id }!!.copy( selected = true, downloaded = true ), @@ -165,9 +165,9 @@ class ServerSetupScreenTest : CoreComposeTest { switch.performClick() stubUiState.update { it.copy( - localCustomModel = false, - localModels = it.localModels.withNewState( - it.localModels.find { m -> m.id == LocalAiModel.CUSTOM.id }!!.copy( + localOnnxCustomModel = false, + localOnnxModels = it.localOnnxModels.withNewState( + it.localOnnxModels.find { m -> m.id == LocalAiModel.CustomOnnx.id }!!.copy( selected = false, ), ), diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModelTest.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModelTest.kt index 12031b8d..37e5d759 100644 --- a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModelTest.kt +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModelTest.kt @@ -12,7 +12,7 @@ import com.shifthackz.aisdv1.domain.interactor.wakelock.WakeLockInterActor import com.shifthackz.aisdv1.domain.preference.PreferenceManager import com.shifthackz.aisdv1.domain.usecase.downloadable.DeleteModelUseCase import com.shifthackz.aisdv1.domain.usecase.downloadable.DownloadModelUseCase -import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalAiModelsUseCase +import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalOnnxModelsUseCase import com.shifthackz.aisdv1.domain.usecase.huggingface.FetchAndGetHuggingFaceModelsUseCase import com.shifthackz.aisdv1.domain.usecase.settings.GetConfigurationUseCase import com.shifthackz.aisdv1.presentation.core.CoreViewModelTest @@ -40,7 +40,7 @@ import org.junit.Test class ServerSetupViewModelTest : CoreViewModelTest() { private val stubGetConfigurationUseCase = mockk() - private val stubGetLocalAiModelsUseCase = mockk() + private val stubGetLocalOnnxModelsUseCase = mockk() private val stubFetchAndGetHuggingFaceModelsUseCase = mockk() private val stubUrlValidator = mockk() private val stubCommonStringValidator = mockk() @@ -56,7 +56,8 @@ class ServerSetupViewModelTest : CoreViewModelTest() { launchSource = LaunchSource.SETTINGS, dispatchersProvider = stubDispatchersProvider, getConfigurationUseCase = stubGetConfigurationUseCase, - getLocalAiModelsUseCase = stubGetLocalAiModelsUseCase, + getLocalOnnxModelsUseCase = stubGetLocalOnnxModelsUseCase, + fetchAndGetHuggingFaceModelsUseCase = stubFetchAndGetHuggingFaceModelsUseCase, urlValidator = stubUrlValidator, stringValidator = stubCommonStringValidator, @@ -80,7 +81,7 @@ class ServerSetupViewModelTest : CoreViewModelTest() { } returns Single.just(Configuration(serverUrl = "https://5598.is.my.favorite.com")) every { - stubGetLocalAiModelsUseCase() + stubGetLocalOnnxModelsUseCase() } returns Single.just(mockLocalAiModels) every { @@ -98,7 +99,7 @@ class ServerSetupViewModelTest : CoreViewModelTest() { fun `initialized, expected UI state updated with correct stub values`() { val state = viewModel.state.value Assert.assertEquals(true, state.huggingFaceModels.isNotEmpty()) - Assert.assertEquals(true, state.localModels.isNotEmpty()) + Assert.assertEquals(true, state.localOnnxModels.isNotEmpty()) Assert.assertEquals("https://5598.is.my.favorite.com", state.serverUrl) Assert.assertEquals(ServerSetupState.AuthType.ANONYMOUS, state.authType) } @@ -123,8 +124,8 @@ class ServerSetupViewModelTest : CoreViewModelTest() { selected = false, ) ) - Assert.assertEquals(true, state.localCustomModel) - Assert.assertEquals(expectedLocalModels, state.localModels) + Assert.assertEquals(true, state.localOnnxCustomModel) + Assert.assertEquals(expectedLocalModels, state.localOnnxModels) } @Test @@ -157,7 +158,7 @@ class ServerSetupViewModelTest : CoreViewModelTest() { val state = viewModel.state.value val expected = true - val actual = state.localModels.any { + val actual = state.localOnnxModels.any { it.downloadState == DownloadState.Downloading(22) } Assert.assertEquals(expected, actual) @@ -201,7 +202,7 @@ class ServerSetupViewModelTest : CoreViewModelTest() { val state = viewModel.state.value val expected = false - val actual = state.localModels.any { + val actual = state.localOnnxModels.any { it.downloadState == DownloadState.Downloading(22) } Assert.assertEquals(expected, actual) @@ -223,7 +224,7 @@ class ServerSetupViewModelTest : CoreViewModelTest() { runTest { val state = viewModel.state.value Assert.assertEquals(Modal.None, state.screenModal) - Assert.assertEquals(false, state.localModels.find { it.id == "1" }!!.downloaded) + Assert.assertEquals(false, state.localOnnxModels.find { it.id == "1" }!!.downloaded) } verify { stubDeleteModelUseCase("1") @@ -234,7 +235,7 @@ class ServerSetupViewModelTest : CoreViewModelTest() { fun `given received SelectLocalModel intent, expected passed LocalModel is selected in UI state`() { viewModel.processIntent(ServerSetupIntent.SelectLocalModel(mockServerSetupStateLocalModel)) val state = viewModel.state.value - Assert.assertEquals(true, state.localModels.find { it.id == "1" }!!.selected) + Assert.assertEquals(true, state.localOnnxModels.find { it.id == "1" }!!.selected) } @Test diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModelTest.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModelTest.kt index 5b00f390..269474df 100644 --- a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModelTest.kt +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModelTest.kt @@ -5,7 +5,7 @@ import com.shifthackz.aisdv1.domain.entity.LocalAiModel import com.shifthackz.aisdv1.domain.entity.ServerSource import com.shifthackz.aisdv1.domain.entity.Settings import com.shifthackz.aisdv1.domain.preference.PreferenceManager -import com.shifthackz.aisdv1.domain.usecase.downloadable.ObserveLocalAiModelsUseCase +import com.shifthackz.aisdv1.domain.usecase.downloadable.ObserveLocalOnnxModelsUseCase import com.shifthackz.aisdv1.domain.usecase.huggingface.FetchAndGetHuggingFaceModelsUseCase import com.shifthackz.aisdv1.domain.usecase.sdmodel.GetStableDiffusionModelsUseCase import com.shifthackz.aisdv1.domain.usecase.sdmodel.SelectStableDiffusionModelUseCase @@ -44,7 +44,7 @@ class EngineSelectionViewModelTest : CoreViewModelTest private val stubGetConfigurationUseCase = mockk() private val stubSelectStableDiffusionModelUseCase = mockk() private val stubGetStableDiffusionModelsUseCase = mockk() - private val stubObserveLocalAiModelsUseCase = mockk() + private val stubObserveLocalAiModelsUseCase = mockk() private val stubFetchAndGetStabilityAiEnginesUseCase = mockk() private val stubFetchAndGetHuggingFaceModelsUseCase = mockk() private val stubFetchAndGetSwarmUiModelsUseCase = mockk() @@ -56,7 +56,7 @@ class EngineSelectionViewModelTest : CoreViewModelTest getConfigurationUseCase = stubGetConfigurationUseCase, selectStableDiffusionModelUseCase = stubSelectStableDiffusionModelUseCase, getStableDiffusionModelsUseCase = stubGetStableDiffusionModelsUseCase, - observeLocalAiModelsUseCase = stubObserveLocalAiModelsUseCase, + observeLocalOnnxModelsUseCase = stubObserveLocalAiModelsUseCase, fetchAndGetStabilityAiEnginesUseCase = stubFetchAndGetStabilityAiEnginesUseCase, getHuggingFaceModelsUseCase = stubFetchAndGetHuggingFaceModelsUseCase, fetchAndGetSwarmUiModelsUseCase = stubFetchAndGetSwarmUiModelsUseCase, @@ -108,7 +108,7 @@ class EngineSelectionViewModelTest : CoreViewModelTest selectedHfModel = "prompthero/openjourney-v4", stEngines = listOf("5598"), selectedStEngine = "5598", - localAiModels = listOf(LocalAiModel.CUSTOM), + localAiModels = listOf(LocalAiModel.CustomOnnx), selectedLocalAiModelId = "CUSTOM", swarmModels = listOf("5598"), selectedSwarmModel = "5598", @@ -232,13 +232,13 @@ class EngineSelectionViewModelTest : CoreViewModelTest mockInitialData(DataTestCase.Mock, ServerSource.LOCAL_MICROSOFT_ONNX) every { - stubPreferenceManager::localModelId.set(any()) + stubPreferenceManager::localOnnxModelId.set(any()) } returns Unit viewModel.processIntent(EngineSelectionIntent("llm_5598")) verify { - stubPreferenceManager::localModelId.set("llm_5598") + stubPreferenceManager::localOnnxModelId.set("llm_5598") } } @@ -262,7 +262,7 @@ class EngineSelectionViewModelTest : CoreViewModelTest huggingFaceModel = "prompthero/openjourney-v4", stabilityAiEngineId = "5598", swarmUiModel = "5598", - localModelId = "CUSTOM", + localOnnxModelId = "CUSTOM", source = source, ), ) diff --git a/storage/schemas/com.shifthackz.aisdv1.storage.db.persistent.PersistentDatabase/6.json b/storage/schemas/com.shifthackz.aisdv1.storage.db.persistent.PersistentDatabase/6.json new file mode 100644 index 00000000..7f7ac2a7 --- /dev/null +++ b/storage/schemas/com.shifthackz.aisdv1.storage.db.persistent.PersistentDatabase/6.json @@ -0,0 +1,254 @@ +{ + "formatVersion": 1, + "database": { + "version": 6, + "identityHash": "6f6ccee56637122e0126c09bb3eb3fdc", + "entities": [ + { + "tableName": "generation_results", + "createSql": "CREATE TABLE IF NOT EXISTS `${TABLE_NAME}` (`id` INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, `image_base_64` TEXT NOT NULL, `original_image_base_64` TEXT NOT NULL, `created_at` INTEGER NOT NULL, `generation_type` TEXT NOT NULL, `prompt` TEXT NOT NULL, `negative_prompt` TEXT NOT NULL, `width` INTEGER NOT NULL, `height` INTEGER NOT NULL, `sampling_steps` INTEGER NOT NULL, `cfg_scale` REAL NOT NULL, `restore_faces` INTEGER NOT NULL, `sampler` TEXT NOT NULL, `seed` TEXT NOT NULL, `sub_seed` TEXT NOT NULL DEFAULT '', `sub_seed_strength` REAL NOT NULL DEFAULT 0.0, `denoising_strength` REAL NOT NULL DEFAULT 0.0)", + "fields": [ + { + "fieldPath": "id", + "columnName": "id", + "affinity": "INTEGER", + "notNull": true + }, + { + "fieldPath": "imageBase64", + "columnName": "image_base_64", + "affinity": "TEXT", + "notNull": true + }, + { + "fieldPath": "originalImageBase64", + "columnName": "original_image_base_64", + "affinity": "TEXT", + "notNull": true + }, + { + "fieldPath": "createdAt", + "columnName": "created_at", + "affinity": "INTEGER", + "notNull": true + }, + { + "fieldPath": "generationType", + "columnName": "generation_type", + "affinity": "TEXT", + "notNull": true + }, + { + "fieldPath": "prompt", + "columnName": "prompt", + "affinity": "TEXT", + "notNull": true + }, + { + "fieldPath": "negativePrompt", + "columnName": "negative_prompt", + "affinity": "TEXT", + "notNull": true + }, + { + "fieldPath": "width", + "columnName": "width", + "affinity": "INTEGER", + "notNull": true + }, + { + "fieldPath": "height", + "columnName": "height", + "affinity": "INTEGER", + "notNull": true + }, + { + "fieldPath": "samplingSteps", + "columnName": "sampling_steps", + "affinity": "INTEGER", + "notNull": true + }, + { + "fieldPath": "cfgScale", + "columnName": "cfg_scale", + "affinity": "REAL", + "notNull": true + }, + { + "fieldPath": "restoreFaces", + "columnName": "restore_faces", + "affinity": "INTEGER", + "notNull": true + }, + { + "fieldPath": "sampler", + "columnName": "sampler", + "affinity": "TEXT", + "notNull": true + }, + { + "fieldPath": "seed", + "columnName": "seed", + "affinity": "TEXT", + "notNull": true + }, + { + "fieldPath": "subSeed", + "columnName": "sub_seed", + "affinity": "TEXT", + "notNull": true, + "defaultValue": "''" + }, + { + "fieldPath": "subSeedStrength", + "columnName": "sub_seed_strength", + "affinity": "REAL", + "notNull": true, + "defaultValue": "0.0" + }, + { + "fieldPath": "denoisingStrength", + "columnName": "denoising_strength", + "affinity": "REAL", + "notNull": true, + "defaultValue": "0.0" + } + ], + "primaryKey": { + "autoGenerate": true, + "columnNames": [ + "id" + ] + }, + "indices": [], + "foreignKeys": [] + }, + { + "tableName": "local_models", + "createSql": "CREATE TABLE IF NOT EXISTS `${TABLE_NAME}` (`id` TEXT NOT NULL, `type` TEXT NOT NULL DEFAULT 'onnx', `name` TEXT NOT NULL, `size` TEXT NOT NULL, `sources` TEXT NOT NULL, PRIMARY KEY(`id`))", + "fields": [ + { + "fieldPath": "id", + "columnName": "id", + "affinity": "TEXT", + "notNull": true + }, + { + "fieldPath": "type", + "columnName": "type", + "affinity": "TEXT", + "notNull": true, + "defaultValue": "'onnx'" + }, + { + "fieldPath": "name", + "columnName": "name", + "affinity": "TEXT", + "notNull": true + }, + { + "fieldPath": "size", + "columnName": "size", + "affinity": "TEXT", + "notNull": true + }, + { + "fieldPath": "sources", + "columnName": "sources", + "affinity": "TEXT", + "notNull": true + } + ], + "primaryKey": { + "autoGenerate": false, + "columnNames": [ + "id" + ] + }, + "indices": [], + "foreignKeys": [] + }, + { + "tableName": "hugging_face_models", + "createSql": "CREATE TABLE IF NOT EXISTS `${TABLE_NAME}` (`id` TEXT NOT NULL, `name` TEXT NOT NULL, `alias` TEXT NOT NULL, `source` TEXT NOT NULL, PRIMARY KEY(`id`))", + "fields": [ + { + "fieldPath": "id", + "columnName": "id", + "affinity": "TEXT", + "notNull": true + }, + { + "fieldPath": "name", + "columnName": "name", + "affinity": "TEXT", + "notNull": true + }, + { + "fieldPath": "alias", + "columnName": "alias", + "affinity": "TEXT", + "notNull": true + }, + { + "fieldPath": "source", + "columnName": "source", + "affinity": "TEXT", + "notNull": true + } + ], + "primaryKey": { + "autoGenerate": false, + "columnNames": [ + "id" + ] + }, + "indices": [], + "foreignKeys": [] + }, + { + "tableName": "supporters", + "createSql": "CREATE TABLE IF NOT EXISTS `${TABLE_NAME}` (`id` INTEGER NOT NULL, `name` TEXT NOT NULL, `date` INTEGER NOT NULL, `message` TEXT NOT NULL, PRIMARY KEY(`id`))", + "fields": [ + { + "fieldPath": "id", + "columnName": "id", + "affinity": "INTEGER", + "notNull": true + }, + { + "fieldPath": "name", + "columnName": "name", + "affinity": "TEXT", + "notNull": true + }, + { + "fieldPath": "date", + "columnName": "date", + "affinity": "INTEGER", + "notNull": true + }, + { + "fieldPath": "message", + "columnName": "message", + "affinity": "TEXT", + "notNull": true + } + ], + "primaryKey": { + "autoGenerate": false, + "columnNames": [ + "id" + ] + }, + "indices": [], + "foreignKeys": [] + } + ], + "views": [], + "setupQueries": [ + "CREATE TABLE IF NOT EXISTS room_master_table (id INTEGER PRIMARY KEY,identity_hash TEXT)", + "INSERT OR REPLACE INTO room_master_table (id,identity_hash) VALUES(42, '6f6ccee56637122e0126c09bb3eb3fdc')" + ] + } +} \ No newline at end of file diff --git a/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/PersistentDatabase.kt b/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/PersistentDatabase.kt index df8eada6..1a59785c 100644 --- a/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/PersistentDatabase.kt +++ b/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/PersistentDatabase.kt @@ -4,18 +4,11 @@ import androidx.room.AutoMigration import androidx.room.Database import androidx.room.RoomDatabase import androidx.room.TypeConverters -import com.shifthackz.aisdv1.storage.converters.DateConverters -import com.shifthackz.aisdv1.storage.converters.ListConverters +import com.shifthackz.aisdv1.storage.converters.* import com.shifthackz.aisdv1.storage.db.persistent.PersistentDatabase.Companion.DB_VERSION -import com.shifthackz.aisdv1.storage.db.persistent.contract.GenerationResultContract -import com.shifthackz.aisdv1.storage.db.persistent.dao.GenerationResultDao -import com.shifthackz.aisdv1.storage.db.persistent.dao.HuggingFaceModelDao -import com.shifthackz.aisdv1.storage.db.persistent.dao.LocalModelDao -import com.shifthackz.aisdv1.storage.db.persistent.dao.SupporterDao -import com.shifthackz.aisdv1.storage.db.persistent.entity.GenerationResultEntity -import com.shifthackz.aisdv1.storage.db.persistent.entity.HuggingFaceModelEntity -import com.shifthackz.aisdv1.storage.db.persistent.entity.LocalModelEntity -import com.shifthackz.aisdv1.storage.db.persistent.entity.SupporterEntity +import com.shifthackz.aisdv1.storage.db.persistent.contract.* +import com.shifthackz.aisdv1.storage.db.persistent.dao.* +import com.shifthackz.aisdv1.storage.db.persistent.entity.* @Database( version = DB_VERSION, @@ -46,6 +39,11 @@ import com.shifthackz.aisdv1.storage.db.persistent.entity.SupporterEntity * Added [SupporterEntity]. */ AutoMigration(from = 4, to = 5), + /** + * Added 1 field to [LocalModelEntity]: + * - [LocalModelContract.TYPE] + */ + AutoMigration(from = 5, to = 6), ], ) @TypeConverters( @@ -60,6 +58,6 @@ internal abstract class PersistentDatabase : RoomDatabase() { companion object { const val DB_NAME = "ai_sd_v1_storage_db" - const val DB_VERSION = 5 + const val DB_VERSION = 6 } } diff --git a/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/contract/LocalModelContract.kt b/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/contract/LocalModelContract.kt index 03b3c03a..2ffebf43 100644 --- a/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/contract/LocalModelContract.kt +++ b/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/contract/LocalModelContract.kt @@ -4,6 +4,7 @@ object LocalModelContract { const val TABLE = "local_models" const val ID = "id" + const val TYPE = "type" const val NAME = "name" const val SIZE = "size" const val SOURCES = "sources" diff --git a/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/dao/LocalModelDao.kt b/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/dao/LocalModelDao.kt index 435d4c69..a2b94a16 100644 --- a/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/dao/LocalModelDao.kt +++ b/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/dao/LocalModelDao.kt @@ -16,9 +16,15 @@ interface LocalModelDao { @Query("SELECT * FROM ${LocalModelContract.TABLE}") fun query(): Single> + @Query("SELECT * FROM ${LocalModelContract.TABLE} WHERE ${LocalModelContract.TYPE} = :type") + fun queryByType(type: String): Single> + @Query("SELECT * FROM ${LocalModelContract.TABLE}") fun observe(): Flowable> + @Query("SELECT * FROM ${LocalModelContract.TABLE} WHERE ${LocalModelContract.TYPE} = :type") + fun observeByType(type: String): Flowable> + @Query("SELECT * FROM ${LocalModelContract.TABLE} WHERE ${LocalModelContract.ID} = :id LIMIT 1") fun queryById(id: String): Single diff --git a/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/entity/LocalModelEntity.kt b/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/entity/LocalModelEntity.kt index ab896641..d69856d2 100644 --- a/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/entity/LocalModelEntity.kt +++ b/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/entity/LocalModelEntity.kt @@ -10,6 +10,8 @@ data class LocalModelEntity( @PrimaryKey(autoGenerate = false) @ColumnInfo(name = LocalModelContract.ID) val id: String, + @ColumnInfo(name = LocalModelContract.TYPE, defaultValue = "onnx") + val type: String, @ColumnInfo(name = LocalModelContract.NAME) val name: String, @ColumnInfo(name = LocalModelContract.SIZE) From f00441139d589a73181145edaaf4207be8d93eac Mon Sep 17 00:00:00 2001 From: ShiftHackZ Date: Tue, 27 Aug 2024 12:10:03 +0300 Subject: [PATCH 05/10] Update tests --- .../aisdv1/presentation/mocks/LocalAiModelMocks.kt | 1 + .../screen/setup/ServerSetupViewModelTest.kt | 10 +++++++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/LocalAiModelMocks.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/LocalAiModelMocks.kt index cdf190e9..62884de9 100644 --- a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/LocalAiModelMocks.kt +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/LocalAiModelMocks.kt @@ -8,6 +8,7 @@ val mockLocalAiModels = listOf( LocalAiModel.CustomOnnx, LocalAiModel( id = "1", + type = LocalAiModel.Type.ONNX, name = "Model 1", size = "5 Gb", sources = listOf("https://example.com/1.html"), diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModelTest.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModelTest.kt index 37e5d759..57d73860 100644 --- a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModelTest.kt +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModelTest.kt @@ -12,6 +12,7 @@ import com.shifthackz.aisdv1.domain.interactor.wakelock.WakeLockInterActor import com.shifthackz.aisdv1.domain.preference.PreferenceManager import com.shifthackz.aisdv1.domain.usecase.downloadable.DeleteModelUseCase import com.shifthackz.aisdv1.domain.usecase.downloadable.DownloadModelUseCase +import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalMediaPipeModelsUseCase import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalOnnxModelsUseCase import com.shifthackz.aisdv1.domain.usecase.huggingface.FetchAndGetHuggingFaceModelsUseCase import com.shifthackz.aisdv1.domain.usecase.settings.GetConfigurationUseCase @@ -41,6 +42,7 @@ class ServerSetupViewModelTest : CoreViewModelTest() { private val stubGetConfigurationUseCase = mockk() private val stubGetLocalOnnxModelsUseCase = mockk() + private val stubGetLocalMediaPipeModelsUseCase = mockk() private val stubFetchAndGetHuggingFaceModelsUseCase = mockk() private val stubUrlValidator = mockk() private val stubCommonStringValidator = mockk() @@ -57,7 +59,7 @@ class ServerSetupViewModelTest : CoreViewModelTest() { dispatchersProvider = stubDispatchersProvider, getConfigurationUseCase = stubGetConfigurationUseCase, getLocalOnnxModelsUseCase = stubGetLocalOnnxModelsUseCase, - + getLocalMediaPipeModelsUseCase = stubGetLocalMediaPipeModelsUseCase, fetchAndGetHuggingFaceModelsUseCase = stubFetchAndGetHuggingFaceModelsUseCase, urlValidator = stubUrlValidator, stringValidator = stubCommonStringValidator, @@ -84,6 +86,10 @@ class ServerSetupViewModelTest : CoreViewModelTest() { stubGetLocalOnnxModelsUseCase() } returns Single.just(mockLocalAiModels) + every { + stubGetLocalMediaPipeModelsUseCase() + } returns Single.just(emptyList()) + every { stubFetchAndGetHuggingFaceModelsUseCase() } returns Single.just(mockHuggingFaceModels) @@ -106,6 +112,7 @@ class ServerSetupViewModelTest : CoreViewModelTest() { @Test fun `given received AllowLocalCustomModel intent, expected Custom local model selected in UI state`() { + viewModel.processIntent(ServerSetupIntent.UpdateServerMode(ServerSource.LOCAL_MICROSOFT_ONNX)) viewModel.processIntent(ServerSetupIntent.AllowLocalCustomModel(true)) val state = viewModel.state.value val expectedLocalModels = listOf( @@ -233,6 +240,7 @@ class ServerSetupViewModelTest : CoreViewModelTest() { @Test fun `given received SelectLocalModel intent, expected passed LocalModel is selected in UI state`() { + viewModel.processIntent(ServerSetupIntent.UpdateServerMode(ServerSource.LOCAL_MICROSOFT_ONNX)) viewModel.processIntent(ServerSetupIntent.SelectLocalModel(mockServerSetupStateLocalModel)) val state = viewModel.state.value Assert.assertEquals(true, state.localOnnxModels.find { it.id == "1" }!!.selected) From c88e209b97b618aa4865d453d65c7569b951a821 Mon Sep 17 00:00:00 2001 From: ShiftHackZ Date: Tue, 27 Aug 2024 13:59:44 +0300 Subject: [PATCH 06/10] Fix unit tests --- .../DownloadableModelRepositoryImpl.kt | 20 ++++++++++++++----- .../DownloadableModelLocalDataSourceTest.kt | 16 +++++++-------- .../aisdv1/data/mocks/LocalAiModelMocks.kt | 1 + .../DownloadableModelRemoteDataSourceTest.kt | 14 ++++++++++++- .../DownloadableModelRepositoryImplTest.kt | 8 ++++++++ .../aisdv1/domain/mocks/LocalAiModelMocks.kt | 1 + .../generation/TextToImageUseCaseImplTest.kt | 3 +++ .../GetConfigurationUseCaseImplTest.kt | 8 ++++++++ .../SetServerConfigurationUseCaseImplTest.kt | 8 ++++++++ .../screen/settings/SettingsViewModel.kt | 2 ++ 10 files changed, 67 insertions(+), 14 deletions(-) diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/repository/DownloadableModelRepositoryImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/repository/DownloadableModelRepositoryImpl.kt index 90e11c7a..6f84189c 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/repository/DownloadableModelRepositoryImpl.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/repository/DownloadableModelRepositoryImpl.kt @@ -1,11 +1,16 @@ package com.shifthackz.aisdv1.data.repository +import com.shifthackz.aisdv1.core.common.appbuild.BuildInfoProvider +import com.shifthackz.aisdv1.core.common.appbuild.BuildType import com.shifthackz.aisdv1.domain.datasource.DownloadableModelDataSource +import com.shifthackz.aisdv1.domain.entity.LocalAiModel import com.shifthackz.aisdv1.domain.repository.DownloadableModelRepository +import io.reactivex.rxjava3.core.Single internal class DownloadableModelRepositoryImpl( private val remoteDataSource: DownloadableModelDataSource.Remote, private val localDataSource: DownloadableModelDataSource.Local, + private val buildInfoProvider: BuildInfoProvider, ) : DownloadableModelRepository { override fun download(id: String) = localDataSource @@ -22,11 +27,16 @@ internal class DownloadableModelRepositoryImpl( .andThen(localDataSource.getAllOnnx()) .onErrorResumeNext { localDataSource.getAllOnnx() } - override fun getAllMediaPipe() = remoteDataSource - .fetch() - .flatMapCompletable(localDataSource::save) - .andThen(localDataSource.getAllMediaPipe()) - .onErrorResumeNext { localDataSource.getAllMediaPipe() } + override fun getAllMediaPipe(): Single> { + if (buildInfoProvider.type == BuildType.FOSS) { + return Single.just(emptyList()) + } + return remoteDataSource + .fetch() + .flatMapCompletable(localDataSource::save) + .andThen(localDataSource.getAllMediaPipe()) + .onErrorResumeNext { localDataSource.getAllMediaPipe() } + } override fun observeAllOnnx() = localDataSource.observeAllOnnx() } diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/local/DownloadableModelLocalDataSourceTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/local/DownloadableModelLocalDataSourceTest.kt index 73a5b0eb..90fe0cbe 100644 --- a/data/src/test/java/com/shifthackz/aisdv1/data/local/DownloadableModelLocalDataSourceTest.kt +++ b/data/src/test/java/com/shifthackz/aisdv1/data/local/DownloadableModelLocalDataSourceTest.kt @@ -40,7 +40,7 @@ class DownloadableModelLocalDataSourceTest { @Test fun `given attempt to get all models, dao returns models list, app build type is PLAY, expected valid domain models list`() { every { - stubDao.query() + stubDao.queryByType(any()) } returns Single.just(mockLocalModelEntities) every { @@ -71,7 +71,7 @@ class DownloadableModelLocalDataSourceTest { @Test fun `given attempt to get all models, dao returns empty models list, app build type is PLAY, expected empty domain models list`() { every { - stubDao.query() + stubDao.queryByType(any()) } returns Single.just(emptyList()) every { @@ -94,7 +94,7 @@ class DownloadableModelLocalDataSourceTest { @Test fun `given attempt to get all models, dao returns models list, app build type is FOSS, expected valid domain models list with CUSTOM model included`() { every { - stubDao.query() + stubDao.queryByType(any()) } returns Single.just(mockLocalModelEntities) every { @@ -128,7 +128,7 @@ class DownloadableModelLocalDataSourceTest { @Test fun `given attempt to get all models, dao returns empty models list, app build type is FOSS, expected domain models list with only CUSTOM model included`() { every { - stubDao.query() + stubDao.queryByType(any()) } returns Single.just(emptyList()) every { @@ -151,7 +151,7 @@ class DownloadableModelLocalDataSourceTest { @Test fun `given attempt to get all models, dao throws exception, expected error value`() { every { - stubDao.query() + stubDao.queryByType(any()) } returns Single.error(stubException) localDataSource @@ -273,7 +273,7 @@ class DownloadableModelLocalDataSourceTest { @Test fun `given attempt to observe all models, dao emits empty list, then list with two items, app build type is PLAY, expected empty list, then domain list with two items`() { every { - stubDao.observe() + stubDao.observeByType(any()) } returns stubLocalModels.toFlowable(BackpressureStrategy.LATEST) every { @@ -308,7 +308,7 @@ class DownloadableModelLocalDataSourceTest { @Test fun `given attempt to observe all models, dao emits empty list, then list with two items, app build type is FOSS, expected list with only CUSTOM model included, then domain list with two items and CUSTOM`() { every { - stubDao.observe() + stubDao.observeByType(any()) } returns stubLocalModels.toFlowable(BackpressureStrategy.LATEST) every { @@ -346,7 +346,7 @@ class DownloadableModelLocalDataSourceTest { @Test fun `given attempt to observe all models, dao throws exception, expected error value`() { every { - stubDao.observe() + stubDao.observeByType(any()) } returns Flowable.error(stubException) localDataSource diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/LocalAiModelMocks.kt b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/LocalAiModelMocks.kt index 6eee82bc..9aaa20a8 100644 --- a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/LocalAiModelMocks.kt +++ b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/LocalAiModelMocks.kt @@ -4,6 +4,7 @@ import com.shifthackz.aisdv1.domain.entity.LocalAiModel val mockLocalAiModel = LocalAiModel( id = "5598", + type = LocalAiModel.Type.ONNX, name = "Model 5598", size = "5 Gb", sources = listOf("https://example.com/1.html"), diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/remote/DownloadableModelRemoteDataSourceTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/remote/DownloadableModelRemoteDataSourceTest.kt index 2b57a295..775fd5a6 100644 --- a/data/src/test/java/com/shifthackz/aisdv1/data/remote/DownloadableModelRemoteDataSourceTest.kt +++ b/data/src/test/java/com/shifthackz/aisdv1/data/remote/DownloadableModelRemoteDataSourceTest.kt @@ -26,7 +26,13 @@ class DownloadableModelRemoteDataSourceTest { whenever(stubApi.fetchOnnxModels()) .thenReturn(Single.just(mockDownloadableModelsResponse)) - val expected = mockDownloadableModelsResponse.mapRawToCheckpointDomain(LocalAiModel.Type.ONNX) + whenever(stubApi.fetchMediaPipeModels()) + .thenReturn(Single.just(mockDownloadableModelsResponse)) + + val expected = listOf( + mockDownloadableModelsResponse.mapRawToCheckpointDomain(LocalAiModel.Type.ONNX), + mockDownloadableModelsResponse.mapRawToCheckpointDomain(LocalAiModel.Type.MediaPipe), + ).flatten() remoteDataSource .fetch() @@ -42,6 +48,9 @@ class DownloadableModelRemoteDataSourceTest { whenever(stubApi.fetchOnnxModels()) .thenReturn(Single.just(emptyList())) + whenever(stubApi.fetchMediaPipeModels()) + .thenReturn(Single.just(emptyList())) + remoteDataSource .fetch() .test() @@ -56,6 +65,9 @@ class DownloadableModelRemoteDataSourceTest { whenever(stubApi.fetchOnnxModels()) .thenReturn(Single.error(stubException)) + whenever(stubApi.fetchMediaPipeModels()) + .thenReturn(Single.error(stubException)) + remoteDataSource .fetch() .test() diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/repository/DownloadableModelRepositoryImplTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/repository/DownloadableModelRepositoryImplTest.kt index acfb8f74..379bc3c5 100644 --- a/data/src/test/java/com/shifthackz/aisdv1/data/repository/DownloadableModelRepositoryImplTest.kt +++ b/data/src/test/java/com/shifthackz/aisdv1/data/repository/DownloadableModelRepositoryImplTest.kt @@ -1,5 +1,7 @@ package com.shifthackz.aisdv1.data.repository +import com.shifthackz.aisdv1.core.common.appbuild.BuildInfoProvider +import com.shifthackz.aisdv1.core.common.appbuild.BuildType import com.shifthackz.aisdv1.data.mocks.mockLocalAiModel import com.shifthackz.aisdv1.data.mocks.mockLocalAiModels import com.shifthackz.aisdv1.domain.datasource.DownloadableModelDataSource @@ -25,14 +27,20 @@ class DownloadableModelRepositoryImplTest { private val stubDownloadState = BehaviorSubject.create() private val stubRemoteDataSource = mockk() private val stubLocalDataSource = mockk() + private val stubBuildInfoProvider = mockk() private val repository = DownloadableModelRepositoryImpl( remoteDataSource = stubRemoteDataSource, localDataSource = stubLocalDataSource, + buildInfoProvider = stubBuildInfoProvider, ) @Before fun initialize() { + every { + stubBuildInfoProvider.type + } returns BuildType.FULL + every { stubLocalDataSource.observeAllOnnx() } returns stubLocalModels.toFlowable(BackpressureStrategy.LATEST) diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/LocalAiModelMocks.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/LocalAiModelMocks.kt index d3c53d8a..5ce3fa3b 100644 --- a/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/LocalAiModelMocks.kt +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/LocalAiModelMocks.kt @@ -6,6 +6,7 @@ val mockLocalAiModels = listOf( LocalAiModel.CustomOnnx, LocalAiModel( id = "1", + type = LocalAiModel.Type.ONNX, name = "Model 1", size = "5 Gb", sources = listOf("https://example.com/1.html"), diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImplTest.kt index b028176d..98505eac 100644 --- a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImplTest.kt +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImplTest.kt @@ -10,6 +10,7 @@ import com.shifthackz.aisdv1.domain.preference.PreferenceManager import com.shifthackz.aisdv1.domain.repository.HordeGenerationRepository import com.shifthackz.aisdv1.domain.repository.HuggingFaceGenerationRepository import com.shifthackz.aisdv1.domain.repository.LocalDiffusionGenerationRepository +import com.shifthackz.aisdv1.domain.repository.MediaPipeGenerationRepository import com.shifthackz.aisdv1.domain.repository.OpenAiGenerationRepository import com.shifthackz.aisdv1.domain.repository.StabilityAiGenerationRepository import com.shifthackz.aisdv1.domain.repository.StableDiffusionGenerationRepository @@ -27,6 +28,7 @@ class TextToImageUseCaseImplTest { private val stubStabilityAiGenerationRepository = mock() private val stubSwarmUiGenerationRepository = mock() private val stubLocalDiffusionGenerationRepository = mock() + private val stubMediaPipeGenerationRepository = mock() private val stubPreferenceManager = mock() private val useCase = TextToImageUseCaseImpl( @@ -37,6 +39,7 @@ class TextToImageUseCaseImplTest { stabilityAiGenerationRepository = stubStabilityAiGenerationRepository, localDiffusionGenerationRepository = stubLocalDiffusionGenerationRepository, swarmUiGenerationRepository = stubSwarmUiGenerationRepository, + mediaPipeGenerationRepository = stubMediaPipeGenerationRepository, preferenceManager = stubPreferenceManager, ) diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImplTest.kt index c8bb1cd2..5a5d12bc 100644 --- a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImplTest.kt +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImplTest.kt @@ -76,6 +76,14 @@ class GetConfigurationUseCaseImplTest { stubPreferenceManager::localOnnxCustomModelPath.get() } returns mockConfiguration.localOnnxModelPath + every { + stubPreferenceManager::localMediaPipeModelId.get() + } returns mockConfiguration.localMediaPipeModelId + + every { + stubPreferenceManager::localMediaPipeCustomModelPath.get() + } returns mockConfiguration.localMediaPipeModelPath + useCase .invoke() .test() diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImplTest.kt index ed48a9d5..a1620120 100644 --- a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImplTest.kt +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImplTest.kt @@ -75,6 +75,14 @@ class SetServerConfigurationUseCaseImplTest { stubPreferenceManager::localOnnxCustomModelPath.set(any()) } returns Unit + every { + stubPreferenceManager::localMediaPipeModelId.set(any()) + } returns Unit + + every { + stubPreferenceManager::localMediaPipeCustomModelPath.set(any()) + } returns Unit + useCase .invoke(mockConfiguration) .test() diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsViewModel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsViewModel.kt index 66a2c97a..d7756582 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsViewModel.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsViewModel.kt @@ -25,6 +25,7 @@ import com.shifthackz.aisdv1.presentation.screen.debug.DebugMenuAccessor import com.shifthackz.aisdv1.presentation.screen.drawer.DrawerIntent import io.reactivex.rxjava3.core.Flowable import io.reactivex.rxjava3.kotlin.subscribeBy +import java.util.concurrent.TimeUnit import com.shifthackz.aisdv1.core.localization.R as LocalizationR class SettingsViewModel( @@ -48,6 +49,7 @@ class SettingsViewModel( private val appVersionProducer = Flowable.fromCallable { buildInfoProvider.toString() } private val sdModelsProducer = getStableDiffusionModelsUseCase() + .timeout(10L, TimeUnit.SECONDS) .toFlowable() .onErrorReturn { emptyList() } From d87469b4cfca173bab491eeaa00689acf514e31d Mon Sep 17 00:00:00 2001 From: ShiftHackZ Date: Tue, 27 Aug 2024 18:21:05 +0300 Subject: [PATCH 07/10] Add project icon --- .gitignore | 1 + .idea/icon.svg | 106 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 107 insertions(+) create mode 100644 .idea/icon.svg diff --git a/.gitignore b/.gitignore index 5076b36e..9590699f 100755 --- a/.gitignore +++ b/.gitignore @@ -44,6 +44,7 @@ captures/ !/.idea/vcs.xml !/.idea/fileTemplates/ !/.idea/inspectionProfiles/ +!/.idea/icon.svg !/.idea/scopes/ !/.idea/codeStyleSettings.xml !/.idea/encodings.xml diff --git a/.idea/icon.svg b/.idea/icon.svg new file mode 100644 index 00000000..7f7eab01 --- /dev/null +++ b/.idea/icon.svg @@ -0,0 +1,106 @@ + + + + + + + + + + + + + + + + + + From 35b4495f84f82d3297f847dfd6e050bfaad9fac5 Mon Sep 17 00:00:00 2001 From: ShiftHackZ Date: Wed, 28 Aug 2024 00:40:13 +0300 Subject: [PATCH 08/10] Optimize BL, translations --- .../src/main/res/values-ru/strings.xml | 4 +- .../src/main/res/values-tr/strings.xml | 4 +- .../src/main/res/values-uk/strings.xml | 4 +- .../src/main/res/values-zh/strings.xml | 4 +- .../src/main/res/values/strings.xml | 5 +- .../local/DownloadableModelLocalDataSource.kt | 6 +- ...alDiffusionGenerationRepositoryImplTest.kt | 11 +- .../domain/entity/LocalDiffusionStatus.kt | 6 + .../feature/diffusion/LocalDiffusion.kt | 5 +- .../LocalDiffusionGenerationRepository.kt | 4 +- .../MediaPipeGenerationRepository.kt | 2 + ...serveLocalDiffusionProcessStatusUseCase.kt | 4 +- ...alDiffusionProcessStatusUseCaseImplTest.kt | 26 +-- .../feature/diffusion/LocalDiffusionImpl.kt | 5 +- .../core/GenerationMviViewModel.kt | 3 +- .../aisdv1/presentation/model/Modal.kt | 3 +- .../navigation/router/main/MainRouterImpl.kt | 2 +- .../gallery/list/GalleryPagingSource.kt | 2 +- .../screen/img2img/ImageToImageScreen.kt | 58 ++++-- .../screen/img2img/ImageToImageState.kt | 2 +- .../loader/ConfigurationLoaderViewModel.kt | 2 + .../screen/setup/ServerSetupScreen.kt | 4 +- .../screen/setup/ServerSetupState.kt | 99 +++++++++- .../screen/setup/ServerSetupViewModel.kt | 186 ++++-------------- .../screen/setup/forms/LocalDiffusionForm.kt | 160 ++++++++------- .../screen/txt2img/TextToImageViewModel.kt | 13 +- .../core/CoreGenerationMviViewModelTest.kt | 4 +- .../screen/setup/ServerSetupViewModelTest.kt | 3 +- 28 files changed, 342 insertions(+), 289 deletions(-) create mode 100644 domain/src/main/java/com/shifthackz/aisdv1/domain/entity/LocalDiffusionStatus.kt diff --git a/core/localization/src/main/res/values-ru/strings.xml b/core/localization/src/main/res/values-ru/strings.xml index acbaa8e7..ef0e1272 100644 --- a/core/localization/src/main/res/values-ru/strings.xml +++ b/core/localization/src/main/res/values-ru/strings.xml @@ -133,9 +133,11 @@ Укажите свйой URL-адрес Swarm UI Модульный веб-интерфейс Stable Diffusion, в котором особое внимание уделяется обеспечению легкого доступа к инструментам, высокой производительности и расширяемости. - Эта конфигурация позволяет запускать генерации Stable Diffusion на вашем телефоне без необходимости подключаться к удаленному серверу/облаку. + Эта конфигурация использует Microsoft ONNX и позволяет запускать генерации Stable Diffusion на вашем телефоне без необходимости подключаться к удаленному серверу/облаку. ВНИМАНИЕ! Функциональность Local Diffusion в бета-тестировании. Не ожидайте высококачественных изображений в локальном режиме. \n\nЭта реализация может не работать должным образом на мобильных телефонах. Производительность и скорость генерации зависят от ресурсов вашего телефона (ЦП, ОЗУ) и размера сгенерированного изображения (чем меньше размер изображения, тем быстрее генерируется). + Эта конфигурация использует Google AI MediaPipe и позволяет запускать генерации Stable Diffusion на вашем телефоне без необходимости подключаться к удаленному серверу/облаку. + Веб Txt2Img diff --git a/core/localization/src/main/res/values-tr/strings.xml b/core/localization/src/main/res/values-tr/strings.xml index eed6dea1..8319a6c1 100644 --- a/core/localization/src/main/res/values-tr/strings.xml +++ b/core/localization/src/main/res/values-tr/strings.xml @@ -133,9 +133,11 @@ Swarm UI URL\'nizi sağlayın Araçlara kolay erişim, yüksek performans ve genişletilebilirlik üzerine odaklanan Modüler, Kararlı Yaygın Web Kullanıcı Arayüzü. - Bu yapılandırma, telefonunuzda uzak sunucuya/buluta bağlanmaya gerek kalmadan Stable Diffusion AI nesillerini çalıştırmanıza izin verir. + Bu yapılandırma Microsoft ONNX çalışma zamanını kullanır ve uzak bir sunucuya/buluta bağlanmaya gerek kalmadan telefonunuzda Stable Diffusion AI nesillerini çalıştırmanıza olanak tanır. Uyarı! Yerel Yayılma işlevi beta testindedir. Yerel modu kullanarak yüksek kaliteli görüntüler beklemeyin. \n\nBu uygulama, güçlü olmayan telefonlarda iyi çalışmayabilir. Oluşturma performansı ve hızı, telefonunuzun kaynaklarına (CPU, RAM) ve oluşturulan görüntünün boyutuna bağlıdır (görüntü boyutu ne kadar küçükse, oluşturma o kadar hızlı olur). + Bu yapılandırma Google AI MediaPipe çalışma zamanını kullanır ve uzak bir sunucuya/buluta bağlanmaya gerek kalmadan telefonunuzda Stable Diffusion AI nesillerini çalıştırmanıza olanak tanır. + Web arayüzü Txt2Img diff --git a/core/localization/src/main/res/values-uk/strings.xml b/core/localization/src/main/res/values-uk/strings.xml index da32118e..87fe4b5d 100644 --- a/core/localization/src/main/res/values-uk/strings.xml +++ b/core/localization/src/main/res/values-uk/strings.xml @@ -133,9 +133,11 @@ Provide your Swarm UI URL Модульний веб-інтерфейс Stable Diffusion з наголосом на полегшення доступу до інструментів, високу продуктивність і розширюваність. - Ця конфігурація дозволяє запускати генерації Stable Diffusion на вашому телефоні без необхідності підключатися до віддаленого сервера/хмари. + Ця конфігурація використовує Microsoft ONNX та дозволяє запускати генерації Stable Diffusion на вашому телефоні без необхідності підключатися до віддаленого сервера/хмари. УВАГА! Функціональність Local Diffusion у бета-тестуванні. Не очікуйте високоякісних зображень у локальному режимі. \n\nЦя реалізація може не працювати належним чином на телефонах із слабкою потужністю. Продуктивність і швидкість генерації залежать від ресурсів вашого телефону (ЦП, ОЗУ) і розміру згенерованого зображення (чим менший розмір зображення, тим швидше генерується). + Ця конфігурація використовує Google AI MediaPipe та дозволяє запускати генерації Stable Diffusion на вашому телефоні без необхідності підключатися до віддаленого сервера/хмари. + Веб Txt2Img diff --git a/core/localization/src/main/res/values-zh/strings.xml b/core/localization/src/main/res/values-zh/strings.xml index c3483d2e..7329dc23 100644 --- a/core/localization/src/main/res/values-zh/strings.xml +++ b/core/localization/src/main/res/values-zh/strings.xml @@ -167,9 +167,11 @@ 本地扩散 - 此配置允许在您的手机上运行Stable Diffusion AI生成,无需连接到远程服务器/云。 + 此配置使用 Microsoft ONNX 运行时,并允许在手机上运行稳定的 Diffusion AI 生成,无需连接到远程服务器/云。 警告!本地扩散功能处于测试版。不要期望使用本地模式获得高质量图像。 \n\n此实现可能在不强大的手机上运行不佳。生成性能和速度取决于您的手机资源(CPU、RAM)和生成的图像大小(图像越小,生成越快)。 + 此配置使用 Google AI MediaPipe 运行时,并允许在手机上运行稳定的 Diffusion AI 生成,无需连接到远程服务器/云。 + 网络界面 diff --git a/core/localization/src/main/res/values/strings.xml b/core/localization/src/main/res/values/strings.xml index 1ed2ae87..b591977e 100755 --- a/core/localization/src/main/res/values/strings.xml +++ b/core/localization/src/main/res/values/strings.xml @@ -72,7 +72,7 @@ Horde Local Diffusion Microsoft ONNX (Beta) ONNX - Local Google AI MediaPipe (Beta) + Local Diffusion Google AI MediaPipe (Beta) MediaPipe Hugging Face Inference HuggingFace @@ -153,11 +153,10 @@ A Modular Stable Diffusion Web-User-Interface, with an emphasis on making tools easily accessible, high performance, and extensibility. Local Diffusion Microsoft ONNX - This configuration uses Microsoft ONNX runtime and allows to run Stable Diffusion AI generations on your phone, with no need to connect to remote server/cloud. Warning! Local Diffusion functionality is in beta-test. Don\'t expect for high quality images using local mode. \n\nThis implementation may not work well on non-powerful phones. Generation performance and speed depends on your phone resources (CPU, RAM) and the size of generated image (the smaller the image size, the faster the generation). - Local Google AI MediaPipe + Local Diffusion Google AI MediaPipe This configuration uses Google AI MediaPipe and allows to run Stable Diffusion AI generations on your phone, with no need to connect to remote server/cloud. Web UI diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/local/DownloadableModelLocalDataSource.kt b/data/src/main/java/com/shifthackz/aisdv1/data/local/DownloadableModelLocalDataSource.kt index d5fdcb66..4ab031c2 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/local/DownloadableModelLocalDataSource.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/local/DownloadableModelLocalDataSource.kt @@ -92,13 +92,15 @@ internal class DownloadableModelLocalDataSource( LocalAiModel.CustomMediaPipe.id -> emitter.onSuccess(true) else -> { - val files = getLocalModelFiles(model.id) + when (model.type) { LocalAiModel.Type.ONNX -> { + val files = getLocalModelFiles(model.id).filter { it.isDirectory } emitter.onSuccess(files.size == 4) } LocalAiModel.Type.MediaPipe -> { + val files = getLocalModelFiles(model.id) emitter.onSuccess(files.isNotEmpty()) } } @@ -116,7 +118,7 @@ internal class DownloadableModelLocalDataSource( private fun getLocalModelFiles(id: String): List { val localModelDir = getLocalModelDirectory(id) if (!localModelDir.exists()) return emptyList() - return localModelDir.listFiles()?.filter { it.isDirectory } ?: emptyList() + return localModelDir.listFiles()?.toList() ?: emptyList() } private fun List.withLocalData() = Observable diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/repository/LocalDiffusionGenerationRepositoryImplTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/repository/LocalDiffusionGenerationRepositoryImplTest.kt index 88cfd515..126edac1 100644 --- a/data/src/test/java/com/shifthackz/aisdv1/data/repository/LocalDiffusionGenerationRepositoryImplTest.kt +++ b/data/src/test/java/com/shifthackz/aisdv1/data/repository/LocalDiffusionGenerationRepositoryImplTest.kt @@ -10,6 +10,7 @@ import com.shifthackz.aisdv1.data.mocks.mockTextToImagePayload import com.shifthackz.aisdv1.domain.datasource.DownloadableModelDataSource import com.shifthackz.aisdv1.domain.datasource.GenerationResultDataSource import com.shifthackz.aisdv1.domain.entity.AiGenerationResult +import com.shifthackz.aisdv1.domain.entity.LocalDiffusionStatus import com.shifthackz.aisdv1.domain.feature.diffusion.LocalDiffusion import com.shifthackz.aisdv1.domain.feature.work.BackgroundWorkObserver import com.shifthackz.aisdv1.domain.gateway.MediaStoreGateway @@ -31,7 +32,7 @@ class LocalDiffusionGenerationRepositoryImplTest { private val stubBitmap = mockk() private val stubException = Throwable("Something went wrong.") - private val stubStatus = BehaviorSubject.create() + private val stubStatus = BehaviorSubject.create() private val stubMediaStoreGateway = mockk() private val stubBase64ToBitmapConverter = mockk() private val stubBitmapToBase64Converter = mockk() @@ -83,17 +84,17 @@ class LocalDiffusionGenerationRepositoryImplTest { fun `given attempt to observe status, local emits two values, expected same values with same order`() { val stubObserver = repository.observeStatus().test() - stubStatus.onNext(LocalDiffusion.Status(1, 2)) + stubStatus.onNext(LocalDiffusionStatus(1, 2)) stubObserver .assertNoErrors() - .assertValueAt(0, LocalDiffusion.Status(1, 2)) + .assertValueAt(0, LocalDiffusionStatus(1, 2)) - stubStatus.onNext(LocalDiffusion.Status(2, 2)) + stubStatus.onNext(LocalDiffusionStatus(2, 2)) stubObserver .assertNoErrors() - .assertValueAt(1, LocalDiffusion.Status(2, 2)) + .assertValueAt(1, LocalDiffusionStatus(2, 2)) } @Test diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/LocalDiffusionStatus.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/LocalDiffusionStatus.kt new file mode 100644 index 00000000..779d8bc6 --- /dev/null +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/LocalDiffusionStatus.kt @@ -0,0 +1,6 @@ +package com.shifthackz.aisdv1.domain.entity + +data class LocalDiffusionStatus( + val current: Int, + val total: Int, +) diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/feature/diffusion/LocalDiffusion.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/feature/diffusion/LocalDiffusion.kt index d1f260e7..afcdefd0 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/feature/diffusion/LocalDiffusion.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/feature/diffusion/LocalDiffusion.kt @@ -1,6 +1,7 @@ package com.shifthackz.aisdv1.domain.feature.diffusion import android.graphics.Bitmap +import com.shifthackz.aisdv1.domain.entity.LocalDiffusionStatus import com.shifthackz.aisdv1.domain.entity.TextToImagePayload import io.reactivex.rxjava3.core.Completable import io.reactivex.rxjava3.core.Observable @@ -9,7 +10,5 @@ import io.reactivex.rxjava3.core.Single interface LocalDiffusion { fun process(payload: TextToImagePayload): Single fun interrupt(): Completable - fun observeStatus(): Observable - - data class Status(val current: Int, val total: Int) + fun observeStatus(): Observable } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/LocalDiffusionGenerationRepository.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/LocalDiffusionGenerationRepository.kt index 0e109843..3ed5c370 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/LocalDiffusionGenerationRepository.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/LocalDiffusionGenerationRepository.kt @@ -1,14 +1,14 @@ package com.shifthackz.aisdv1.domain.repository import com.shifthackz.aisdv1.domain.entity.AiGenerationResult +import com.shifthackz.aisdv1.domain.entity.LocalDiffusionStatus import com.shifthackz.aisdv1.domain.entity.TextToImagePayload -import com.shifthackz.aisdv1.domain.feature.diffusion.LocalDiffusion import io.reactivex.rxjava3.core.Completable import io.reactivex.rxjava3.core.Observable import io.reactivex.rxjava3.core.Single interface LocalDiffusionGenerationRepository { - fun observeStatus(): Observable + fun observeStatus(): Observable fun generateFromText(payload: TextToImagePayload): Single fun interruptGeneration(): Completable } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/MediaPipeGenerationRepository.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/MediaPipeGenerationRepository.kt index ac6088e5..a7b65fb5 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/MediaPipeGenerationRepository.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/MediaPipeGenerationRepository.kt @@ -1,7 +1,9 @@ package com.shifthackz.aisdv1.domain.repository import com.shifthackz.aisdv1.domain.entity.AiGenerationResult +import com.shifthackz.aisdv1.domain.entity.LocalDiffusionStatus import com.shifthackz.aisdv1.domain.entity.TextToImagePayload +import io.reactivex.rxjava3.core.Observable import io.reactivex.rxjava3.core.Single interface MediaPipeGenerationRepository { diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/ObserveLocalDiffusionProcessStatusUseCase.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/ObserveLocalDiffusionProcessStatusUseCase.kt index 3a65f91b..9530b6b0 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/ObserveLocalDiffusionProcessStatusUseCase.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/ObserveLocalDiffusionProcessStatusUseCase.kt @@ -1,8 +1,8 @@ package com.shifthackz.aisdv1.domain.usecase.generation -import com.shifthackz.aisdv1.domain.feature.diffusion.LocalDiffusion +import com.shifthackz.aisdv1.domain.entity.LocalDiffusionStatus import io.reactivex.rxjava3.core.Observable interface ObserveLocalDiffusionProcessStatusUseCase { - operator fun invoke(): Observable + operator fun invoke(): Observable } diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/ObserveLocalDiffusionProcessStatusUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/ObserveLocalDiffusionProcessStatusUseCaseImplTest.kt index 528e2e6e..47836c4e 100644 --- a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/ObserveLocalDiffusionProcessStatusUseCaseImplTest.kt +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/ObserveLocalDiffusionProcessStatusUseCaseImplTest.kt @@ -2,7 +2,7 @@ package com.shifthackz.aisdv1.domain.usecase.generation import com.nhaarman.mockitokotlin2.mock import com.nhaarman.mockitokotlin2.whenever -import com.shifthackz.aisdv1.domain.feature.diffusion.LocalDiffusion +import com.shifthackz.aisdv1.domain.entity.LocalDiffusionStatus import com.shifthackz.aisdv1.domain.repository.LocalDiffusionGenerationRepository import io.reactivex.rxjava3.core.Observable import io.reactivex.rxjava3.subjects.BehaviorSubject @@ -12,7 +12,7 @@ import org.junit.Test class ObserveLocalDiffusionProcessStatusUseCaseImplTest { private val stubException = Throwable("Error loading Local Diffusion.") - private val stubLocalStatus = BehaviorSubject.create() + private val stubLocalStatus = BehaviorSubject.create() private val stubRepository = mock() private val useCase = ObserveLocalDiffusionProcessStatusUseCaseImpl(stubRepository) @@ -27,23 +27,23 @@ class ObserveLocalDiffusionProcessStatusUseCaseImplTest { fun `given repository processes three steps, expected three valid status values`() { val stubObserver = useCase().test() - stubLocalStatus.onNext(LocalDiffusion.Status(1, 3)) + stubLocalStatus.onNext(LocalDiffusionStatus(1, 3)) stubObserver .assertNoErrors() - .assertValueAt(0, LocalDiffusion.Status(1, 3)) + .assertValueAt(0, LocalDiffusionStatus(1, 3)) - stubLocalStatus.onNext(LocalDiffusion.Status(2, 3)) + stubLocalStatus.onNext(LocalDiffusionStatus(2, 3)) stubObserver .assertNoErrors() - .assertValueAt(1, LocalDiffusion.Status(2, 3)) + .assertValueAt(1, LocalDiffusionStatus(2, 3)) - stubLocalStatus.onNext(LocalDiffusion.Status(3, 3)) + stubLocalStatus.onNext(LocalDiffusionStatus(3, 3)) stubObserver .assertNoErrors() - .assertValueAt(2, LocalDiffusion.Status(3, 3)) + .assertValueAt(2, LocalDiffusionStatus(3, 3)) .assertValueCount(3) } @@ -51,23 +51,23 @@ class ObserveLocalDiffusionProcessStatusUseCaseImplTest { fun `given repository processes two steps, emits same step twice, expected two valid status values`() { val stubObserver = useCase().test() - stubLocalStatus.onNext(LocalDiffusion.Status(1, 2)) + stubLocalStatus.onNext(LocalDiffusionStatus(1, 2)) stubObserver .assertNoErrors() - .assertValueAt(0, LocalDiffusion.Status(1, 2)) + .assertValueAt(0, LocalDiffusionStatus(1, 2)) - stubLocalStatus.onNext(LocalDiffusion.Status(1, 2)) + stubLocalStatus.onNext(LocalDiffusionStatus(1, 2)) stubObserver .assertNoErrors() .assertValueCount(1) - stubLocalStatus.onNext(LocalDiffusion.Status(2, 2)) + stubLocalStatus.onNext(LocalDiffusionStatus(2, 2)) stubObserver .assertNoErrors() - .assertValueAt(1, LocalDiffusion.Status(2, 2)) + .assertValueAt(1, LocalDiffusionStatus(2, 2)) .assertValueCount(2) } diff --git a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/LocalDiffusionImpl.kt b/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/LocalDiffusionImpl.kt index 56b3f39d..775cad2c 100644 --- a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/LocalDiffusionImpl.kt +++ b/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/LocalDiffusionImpl.kt @@ -4,6 +4,7 @@ import ai.onnxruntime.OnnxTensor import android.graphics.Bitmap import com.shifthackz.aisdv1.core.common.log.debugLog import com.shifthackz.aisdv1.core.common.log.errorLog +import com.shifthackz.aisdv1.domain.entity.LocalDiffusionStatus import com.shifthackz.aisdv1.domain.entity.TextToImagePayload import com.shifthackz.aisdv1.domain.feature.diffusion.LocalDiffusion import com.shifthackz.aisdv1.feature.diffusion.LocalDiffusionContract.TAG @@ -20,7 +21,7 @@ internal class LocalDiffusionImpl( private val ortEnvironmentProvider: OrtEnvironmentProvider, ) : LocalDiffusion { - private val statusSubject: PublishSubject = PublishSubject.create() + private val statusSubject: PublishSubject = PublishSubject.create() override fun process(payload: TextToImagePayload): Single = Single.create { emitter -> try { @@ -31,7 +32,7 @@ internal class LocalDiffusionImpl( uNet.setCallback(object : UNet.Callback { override fun onStep(maxStep: Int, step: Int) { debugLog(TAG, "Received step update: ${maxStep}/${step}") - statusSubject.onNext(LocalDiffusion.Status(step, maxStep)) + statusSubject.onNext(LocalDiffusionStatus(step, maxStep)) } override fun onBuildImage(status: Int, bitmap: Bitmap?) { diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/core/GenerationMviViewModel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/core/GenerationMviViewModel.kt index 714f0ef9..689fc79c 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/core/GenerationMviViewModel.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/core/GenerationMviViewModel.kt @@ -9,6 +9,7 @@ import com.shifthackz.aisdv1.core.common.schedulers.subscribeOnMainThread import com.shifthackz.aisdv1.core.validation.dimension.DimensionValidator import com.shifthackz.aisdv1.core.viewmodel.MviRxViewModel import com.shifthackz.aisdv1.domain.entity.HordeProcessStatus +import com.shifthackz.aisdv1.domain.entity.LocalDiffusionStatus import com.shifthackz.aisdv1.domain.entity.OpenAiSize import com.shifthackz.aisdv1.domain.entity.ServerSource import com.shifthackz.aisdv1.domain.entity.StabilityAiSampler @@ -118,7 +119,7 @@ abstract class GenerationMviViewModel? get() = status?.let { (current, total) -> current to total } diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/navigation/router/main/MainRouterImpl.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/navigation/router/main/MainRouterImpl.kt index 125eb82c..a9a2a438 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/navigation/router/main/MainRouterImpl.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/navigation/router/main/MainRouterImpl.kt @@ -47,7 +47,7 @@ internal class MainRouterImpl : MainRouter { override fun navigateToServerSetup(source: LaunchSource) { effectSubject.onNext(NavigationEffect.Navigate.RouteBuilder("${Constants.ROUTE_SERVER_SETUP}/${source.ordinal}") { if (source == LaunchSource.SPLASH) { - popUpTo(Constants.ROUTE_SPLASH) { + popUpTo(0) { inclusive = true } } diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/gallery/list/GalleryPagingSource.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/gallery/list/GalleryPagingSource.kt index d5ad617a..482ea2d9 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/gallery/list/GalleryPagingSource.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/gallery/list/GalleryPagingSource.kt @@ -32,7 +32,7 @@ class GalleryPagingSource( limit = pageSize, offset = pageNext * Constants.PAGINATION_PAYLOAD_SIZE, ) - .subscribeOn(schedulersProvider.io) + .subscribeOn(schedulersProvider.computation) .flatMapObservable { Observable.fromIterable(it) } .map { ai -> ai.id to ai.image } .map { (id, base64) -> id to Input(base64) } diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageScreen.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageScreen.kt index 2eee81b6..000bca56 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageScreen.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageScreen.kt @@ -54,6 +54,9 @@ import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.sp import com.shifthackz.aisdv1.core.common.file.FileProviderDescriptor import com.shifthackz.aisdv1.core.common.math.roundTo +import com.shifthackz.aisdv1.core.model.UiText +import com.shifthackz.aisdv1.core.model.asString +import com.shifthackz.aisdv1.core.model.asUiText import com.shifthackz.aisdv1.core.ui.MviComponent import com.shifthackz.aisdv1.domain.entity.AiGenerationResult import com.shifthackz.aisdv1.domain.entity.ServerSource @@ -233,17 +236,33 @@ private fun ScreenContent( ) Text( modifier = Modifier.padding(top = 14.dp), - text = stringResource( - if (state.mode == ServerSource.LOCAL_MICROSOFT_ONNX) LocalizationR.string.local_no_img2img_support_sub_title - else LocalizationR.string.dalle_no_img2img_support_sub_title - ), + text = when (state.mode) { + ServerSource.OPEN_AI -> LocalizationR.string + .dalle_no_img2img_support_sub_title + .asUiText() + + ServerSource.LOCAL_MICROSOFT_ONNX, + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> LocalizationR.string + .local_no_img2img_support_sub_title + .asUiText() + + else -> UiText.empty + }.asString(), ) Text( modifier = Modifier.padding(top = 14.dp), - text = stringResource( - if (state.mode == ServerSource.LOCAL_MICROSOFT_ONNX) LocalizationR.string.local_no_img2img_support_sub_title_2 - else LocalizationR.string.dalle_no_img2img_support_sub_title_2 - ), + text = when (state.mode) { + ServerSource.OPEN_AI -> LocalizationR.string + .dalle_no_img2img_support_sub_title_2 + .asUiText() + + ServerSource.LOCAL_MICROSOFT_ONNX, + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> LocalizationR.string + .local_no_img2img_support_sub_title_2 + .asUiText() + + else -> UiText.empty + }.asString(), ) } } @@ -252,6 +271,7 @@ private fun ScreenContent( bottomBar = { val isEnabled = when (state.mode) { ServerSource.LOCAL_MICROSOFT_ONNX, + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE, ServerSource.OPEN_AI -> true else -> !state.hasValidationErrors && !state.imageState.isEmpty @@ -271,20 +291,28 @@ private fun ScreenContent( keyboardController?.hide() when (state.mode) { ServerSource.OPEN_AI, - ServerSource.LOCAL_MICROSOFT_ONNX -> processIntent(GenerationMviIntent.Configuration) + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE, + ServerSource.LOCAL_MICROSOFT_ONNX -> processIntent( + GenerationMviIntent.Configuration + ) else -> { promptChipTextFieldState.value.text.takeIf(String::isNotBlank) ?.let { "${state.prompt}, ${it.trim()}" } ?.let(GenerationMviIntent.Update::Prompt) ?.let(processIntent::invoke) - ?.also { promptChipTextFieldState.value = TextFieldValue("") } + ?.also { + promptChipTextFieldState.value = TextFieldValue("") + } negativePromptChipTextFieldState.value.text.takeIf(String::isNotBlank) ?.let { "${state.negativePrompt}, ${it.trim()}" } ?.let(GenerationMviIntent.Update::NegativePrompt) ?.let(processIntent::invoke) - ?.also { negativePromptChipTextFieldState.value = TextFieldValue("") } + ?.also { + negativePromptChipTextFieldState.value = + TextFieldValue("") + } processIntent(GenerationMviIntent.Generate) } @@ -292,8 +320,11 @@ private fun ScreenContent( }, enabled = isEnabled, ) { - if (state.mode != ServerSource.LOCAL_MICROSOFT_ONNX) { - Icon( + when (state.mode) { + ServerSource.LOCAL_MICROSOFT_ONNX, + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> Unit + + else -> Icon( modifier = Modifier.size(18.dp), imageVector = Icons.Default.AutoFixNormal, contentDescription = "Imagine", @@ -304,6 +335,7 @@ private fun ScreenContent( text = stringResource( id = when (state.mode) { ServerSource.LOCAL_MICROSOFT_ONNX, + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE, ServerSource.OPEN_AI -> LocalizationR.string.action_change_configuration else -> LocalizationR.string.action_generate diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageState.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageState.kt index 533b7b9b..3e301a66 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageState.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageState.kt @@ -90,7 +90,7 @@ data class ImageToImageState( heightValidationError: UiText?, nsfw: Boolean, batchCount: Int, - generateButtonEnabled: Boolean + generateButtonEnabled: Boolean, ): GenerationMviState = copy( onBoardingDemo = onBoardingDemo, screenModal = screenModal, diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/loader/ConfigurationLoaderViewModel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/loader/ConfigurationLoaderViewModel.kt index c97b3966..e3833534 100755 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/loader/ConfigurationLoaderViewModel.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/loader/ConfigurationLoaderViewModel.kt @@ -11,6 +11,7 @@ import com.shifthackz.aisdv1.presentation.navigation.router.main.MainRouter import com.shifthackz.android.core.mvi.EmptyEffect import com.shifthackz.android.core.mvi.EmptyIntent import io.reactivex.rxjava3.kotlin.subscribeBy +import java.util.concurrent.TimeUnit import com.shifthackz.aisdv1.core.localization.R as LocalizationR class ConfigurationLoaderViewModel( @@ -28,6 +29,7 @@ class ConfigurationLoaderViewModel( init { !dataPreLoaderUseCase() + .timeout(15L, TimeUnit.SECONDS) .doOnSubscribe { updateState { ConfigurationLoaderState.StatusNotification( diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupScreen.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupScreen.kt index f4e5b0f2..612b4c90 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupScreen.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupScreen.kt @@ -172,7 +172,9 @@ fun ServerSetupScreenContent( id = when (state.step) { ServerSetupState.Step.SOURCE -> LocalizationR.string.next else -> when (state.mode) { - ServerSource.LOCAL_MICROSOFT_ONNX -> LocalizationR.string.action_setup + ServerSource.LOCAL_MICROSOFT_ONNX, + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> LocalizationR.string + .action_setup else -> LocalizationR.string.action_connect } }, diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupState.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupState.kt index 639b3c63..44a37715 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupState.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupState.kt @@ -5,9 +5,11 @@ import com.shifthackz.aisdv1.core.common.links.LinksProvider import com.shifthackz.aisdv1.core.model.UiText import com.shifthackz.aisdv1.domain.entity.Configuration import com.shifthackz.aisdv1.domain.entity.DownloadState +import com.shifthackz.aisdv1.domain.entity.LocalAiModel import com.shifthackz.aisdv1.domain.entity.ServerSource import com.shifthackz.aisdv1.domain.feature.auth.AuthorizationCredentials import com.shifthackz.aisdv1.presentation.model.Modal +import com.shifthackz.aisdv1.presentation.screen.setup.mappers.withNewState import com.shifthackz.aisdv1.presentation.utils.Constants import com.shifthackz.android.core.mvi.MviState import org.koin.core.component.KoinComponent @@ -59,6 +61,13 @@ data class ServerSetupState( localMediaPipeCustomModel } + val localCustomModelPath: String + get() = if (mode == ServerSource.LOCAL_MICROSOFT_ONNX) { + localOnnxCustomModelPath + } else { + localMediaPipeCustomModelPath + } + val localModels: List get() = if (mode == ServerSource.LOCAL_MICROSOFT_ONNX) { localOnnxModels @@ -85,13 +94,101 @@ data class ServerSetupState( ) fun withCredentials(value: AuthorizationCredentials) = when (value) { - is AuthorizationCredentials.HttpBasic -> this.copy( + is AuthorizationCredentials.HttpBasic -> copy( login = value.login, password = value.password, ) + AuthorizationCredentials.None -> this } + fun withLocalCustomModelPath(value: String): ServerSetupState = when (mode) { + ServerSource.LOCAL_MICROSOFT_ONNX -> copy( + localOnnxCustomModelPath = value, + localCustomOnnxPathValidationError = null, + ) + + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> copy( + localMediaPipeCustomModelPath = value, + localCustomMediaPipePathValidationError = null, + ) + + else -> this + } + + fun withUpdatedLocalModel(value: LocalModel): ServerSetupState = when (mode) { + ServerSource.LOCAL_MICROSOFT_ONNX -> copy( + localOnnxModels = localOnnxModels.withNewState(value) + ) + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> copy( + localMediaPipeModels = localMediaPipeModels.withNewState(value) + ) + else -> this + } + + fun withDeletedLocalModel(value: LocalModel): ServerSetupState = when (mode) { + ServerSource.LOCAL_MICROSOFT_ONNX -> copy( + screenModal = Modal.None, + localOnnxModels = localOnnxModels.withNewState( + value.copy( + downloadState = DownloadState.Unknown, + downloaded = false, + ), + ) + ) + + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> copy( + screenModal = Modal.None, + localMediaPipeModels = localMediaPipeModels.withNewState( + value.copy( + downloadState = DownloadState.Unknown, + downloaded = false, + ), + ) + ) + + else -> copy(screenModal = Modal.None) + } + + fun withSelectedLocalModel(value: LocalModel): ServerSetupState = when (mode) { + ServerSource.LOCAL_MICROSOFT_ONNX -> copy( + localOnnxModels = localOnnxModels.withNewState( + value.copy(selected = true), + ), + ) + + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> copy( + localMediaPipeModels = localMediaPipeModels.withNewState( + value.copy(selected = true), + ), + ) + + else -> this + } + + fun withAllowCustomModel(value: Boolean): ServerSetupState { + fun List.updateCustomModelSelection(id: String) = withNewState( + find { m -> m.id == id }?.copy(selected = value) + ) + return when (mode) { + ServerSource.LOCAL_MICROSOFT_ONNX -> this.copy( + localOnnxCustomModel = value, + localOnnxModels = localOnnxModels.updateCustomModelSelection( + id = LocalAiModel.CustomOnnx.id, + ), + ) + + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> this.copy( + localMediaPipeCustomModel = value, + localMediaPipeModels = localMediaPipeModels.updateCustomModelSelection( + id = LocalAiModel.CustomMediaPipe.id, + ), + ) + + else -> this + } + } + enum class Step { SOURCE, CONFIGURE; diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt index 68f3c47b..d054a301 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt @@ -14,7 +14,6 @@ import com.shifthackz.aisdv1.core.validation.url.UrlValidator import com.shifthackz.aisdv1.core.viewmodel.MviRxViewModel import com.shifthackz.aisdv1.domain.entity.DownloadState import com.shifthackz.aisdv1.domain.entity.HuggingFaceModel -import com.shifthackz.aisdv1.domain.entity.LocalAiModel import com.shifthackz.aisdv1.domain.entity.ServerSource import com.shifthackz.aisdv1.domain.feature.auth.AuthorizationCredentials import com.shifthackz.aisdv1.domain.interactor.settings.SetupConnectionInterActor @@ -33,7 +32,6 @@ import com.shifthackz.aisdv1.presentation.screen.setup.mappers.allowedModes import com.shifthackz.aisdv1.presentation.screen.setup.mappers.mapLocalCustomMediaPipeSwitchState import com.shifthackz.aisdv1.presentation.screen.setup.mappers.mapLocalCustomOnnxSwitchState import com.shifthackz.aisdv1.presentation.screen.setup.mappers.mapToUi -import com.shifthackz.aisdv1.presentation.screen.setup.mappers.withNewState import com.shifthackz.aisdv1.presentation.utils.Constants import io.reactivex.rxjava3.core.Single import io.reactivex.rxjava3.disposables.Disposable @@ -124,27 +122,7 @@ class ServerSetupViewModel( override fun processIntent(intent: ServerSetupIntent) = when (intent) { is ServerSetupIntent.AllowLocalCustomModel -> updateState { state -> - when (state.mode) { - ServerSource.LOCAL_MICROSOFT_ONNX -> state.copy( - localOnnxCustomModel = intent.allow, - localOnnxModels = state.localOnnxModels.withNewState( - state.localOnnxModels.find { m -> m.id == LocalAiModel.CustomOnnx.id }?.copy( - selected = intent.allow, - ), - ), - ) - - ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> state.copy( - localMediaPipeCustomModel = intent.allow, - localMediaPipeModels = state.localMediaPipeModels.withNewState( - state.localMediaPipeModels.find { m -> m.id == LocalAiModel.CustomMediaPipe.id}?.copy( - selected = intent.allow - ) - ) - ) - - else -> state - } + state.withAllowCustomModel(intent.allow) } ServerSetupIntent.DismissDialog -> setScreenModal(Modal.None) @@ -155,35 +133,11 @@ class ServerSetupViewModel( !deleteModelUseCase(intent.model.id) .subscribeOnMainThread(schedulersProvider) .subscribeBy(::errorLog) - it.copy( - screenModal = Modal.None, - localOnnxModels = currentState.localOnnxModels.withNewState( - intent.model.copy( - downloadState = DownloadState.Unknown, - downloaded = false, - ), - ), - ) + it.withDeletedLocalModel(intent.model) } - is ServerSetupIntent.SelectLocalModel -> { - updateState { state -> - when (state.mode) { - ServerSource.LOCAL_MICROSOFT_ONNX -> state.copy( - localOnnxModels = state.localOnnxModels.withNewState( - intent.model.copy(selected = true), - ), - ) - - ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> state.copy( - localMediaPipeModels = state.localMediaPipeModels.withNewState( - intent.model.copy(selected = true), - ), - ) - - else -> state - } - } + is ServerSetupIntent.SelectLocalModel -> updateState { state -> + state.withSelectedLocalModel(intent.model) } ServerSetupIntent.MainButtonClick -> when (currentState.step) { @@ -270,19 +224,7 @@ class ServerSetupViewModel( ServerSetupIntent.ConnectToLocalHost -> connectToServer() is ServerSetupIntent.SelectLocalModelPath -> updateState { state -> - when (state.mode) { - ServerSource.LOCAL_MICROSOFT_ONNX -> state.copy( - localOnnxCustomModelPath = intent.value, - localCustomOnnxPathValidationError = null, - ) - - ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> state.copy( - localMediaPipeCustomModelPath = intent.value, - localCustomMediaPipePathValidationError = null, - ) - - else -> state - } + state.withLocalCustomModelPath(intent.value) } } @@ -471,126 +413,70 @@ class ServerSetupViewModel( return setupConnectionInterActor.connectToMediaPipe(localModelId) } - private fun localModelDownloadClickReducer(localModel: ServerSetupState.LocalModel) { + private fun localModelDownloadClickReducer(value: ServerSetupState.LocalModel) { + fun localModel(): ServerSetupState.LocalModel = + currentState.localModels.firstOrNull { it.id == value.id } + ?.let { value.copy(selected = it.selected) } + ?: value + when { // User cancels download - localModel.downloadState is DownloadState.Downloading -> { - val index = downloadDisposables.indexOfFirst { it.first == localModel.id } + localModel().downloadState is DownloadState.Downloading -> { + val index = downloadDisposables.indexOfFirst { it.first == localModel().id } if (index != -1) { downloadDisposables[index].second.dispose() downloadDisposables.removeAt(index) } - !deleteModelUseCase(localModel.id) + !deleteModelUseCase(localModel().id) .subscribeOnMainThread(schedulersProvider) .subscribeBy(::errorLog) updateState { state -> - when (state.mode) { - ServerSource.LOCAL_MICROSOFT_ONNX -> { - state.copy( - localOnnxModels = state.localOnnxModels.withNewState( - localModel.copy(downloadState = DownloadState.Unknown), - ), - ) - } - - ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> { - state.copy( - localMediaPipeModels = state.localMediaPipeModels.withNewState( - localModel.copy(downloadState = DownloadState.Unknown), - ), - ) - } - - else -> state - } + state.withUpdatedLocalModel( + value = localModel().copy(downloadState = DownloadState.Unknown), + ) } } // User deletes local model - localModel.downloaded -> updateState { - it.copy(screenModal = Modal.DeleteLocalModelConfirm(localModel)) + localModel().downloaded -> updateState { + it.copy(screenModal = Modal.DeleteLocalModelConfirm(localModel())) } // User requested new download operation else -> { updateState { state -> - when (state.mode) { - ServerSource.LOCAL_MICROSOFT_ONNX -> { - state.copy( - localOnnxModels = state.localOnnxModels.withNewState( - localModel.copy( - downloadState = DownloadState.Downloading(), - ), - ), - ) - } - - ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> { - state.copy( - localMediaPipeModels = state.localMediaPipeModels.withNewState( - localModel.copy( - downloadState = DownloadState.Downloading(), - ), - ), - ) - } - - else -> state - } + state.withUpdatedLocalModel( + localModel().copy(downloadState = DownloadState.Downloading()), + ) } - !downloadModelUseCase(localModel.id) + !downloadModelUseCase(localModel().id) .distinctUntilChanged() .doOnSubscribe { wakeLockInterActor.acquireWakelockUseCase() } .doFinally { wakeLockInterActor.releaseWakeLockUseCase() } - .subscribeOnMainThread(schedulersProvider).subscribeBy( + .subscribeOnMainThread(schedulersProvider) + .subscribeBy( onError = { t -> errorLog(t) val message = t.localizedMessage ?: "Error" updateState { state -> - state.copy( - localOnnxModels = state.localOnnxModels.withNewState( - localModel.copy( - downloadState = DownloadState.Error(t), - ), + state.withUpdatedLocalModel( + localModel().copy( + downloadState = DownloadState.Error(t), ), - localMediaPipeModels = state.localMediaPipeModels.withNewState( - localModel.copy( - downloadState = DownloadState.Error(t), - ), - ) ) } setScreenModal(Modal.Error(message.asUiText())) }, onNext = { downloadState -> - updateState { - when (downloadState) { - is DownloadState.Complete -> it.copy( - localOnnxModels = it.localOnnxModels.withNewState( - localModel.copy( - downloadState = downloadState, - downloaded = true, - ), - ), - localMediaPipeModels = it.localMediaPipeModels.withNewState( - localModel.copy( - downloadState = downloadState, - downloaded = true, - ), - ), - ) - - else -> it.copy( - localOnnxModels = it.localOnnxModels.withNewState( - localModel.copy(downloadState = downloadState), - ), - localMediaPipeModels = it.localMediaPipeModels.withNewState( - localModel.copy(downloadState = downloadState), - ), - ) - } + updateState { state -> + state.withUpdatedLocalModel( + localModel().copy( + downloadState = downloadState, + downloaded = downloadState is DownloadState.Complete + ), + ) } }, ) - .also { downloadDisposables.add(localModel.id to it) } + .also { downloadDisposables.add(localModel().id to it) } } } } diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/LocalDiffusionForm.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/LocalDiffusionForm.kt index 7a8b65ce..4408ad5d 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/LocalDiffusionForm.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/LocalDiffusionForm.kt @@ -93,6 +93,7 @@ fun LocalDiffusionForm( is DownloadState.Downloading -> Icons.Outlined.FileDownload else -> when { model.id == LocalAiModel.CustomOnnx.id -> Icons.Outlined.Landslide + model.id == LocalAiModel.CustomMediaPipe.id -> Icons.Outlined.Landslide model.downloaded -> Icons.Outlined.FileDownloadDone else -> Icons.Outlined.FileDownloadOff } @@ -114,8 +115,11 @@ fun LocalDiffusionForm( overflow = TextOverflow.Ellipsis, maxLines = 2 ) - if (model.id != LocalAiModel.CustomOnnx.id) { - Text( + when (model.id) { + LocalAiModel.CustomOnnx.id, + LocalAiModel.CustomMediaPipe.id -> Unit + + else -> Text( text = model.size, maxLines = 1 ) @@ -156,81 +160,83 @@ fun LocalDiffusionForm( text = stringResource(id = LocalizationR.string.model_local_custom_title), style = MaterialTheme.typography.bodyMedium, ) - Spacer(modifier = Modifier.height(4.dp)) - Text( - text = stringResource(id = LocalizationR.string.model_local_custom_sub_title), - style = MaterialTheme.typography.bodyMedium, - ) - Spacer(modifier = Modifier.height(4.dp)) + if (model.id == LocalAiModel.CustomOnnx.id) { + Spacer(modifier = Modifier.height(4.dp)) + Text( + text = stringResource(id = LocalizationR.string.model_local_custom_sub_title), + style = MaterialTheme.typography.bodyMedium, + ) + Spacer(modifier = Modifier.height(4.dp)) - fun folderModifier(treeNum: Int) = - Modifier.padding(start = (treeNum - 1) * 12.dp) + fun folderModifier(treeNum: Int) = + Modifier.padding(start = (treeNum - 1) * 12.dp) - val folderStyle = MaterialTheme.typography.bodySmall - Text( - modifier = Modifier.padding(start = 12.dp), - text = state.localOnnxCustomModelPath, - style = folderStyle, - ) + val folderStyle = MaterialTheme.typography.bodySmall + Text( + modifier = Modifier.padding(start = 12.dp), + text = state.localOnnxCustomModelPath, + style = folderStyle, + ) - Text( - modifier = folderModifier(3), - text = "text_encoder", - style = folderStyle, - ) - Text( - modifier = folderModifier(4), - text = "model.ort", - style = folderStyle, - ) + Text( + modifier = folderModifier(3), + text = "text_encoder", + style = folderStyle, + ) + Text( + modifier = folderModifier(4), + text = "model.ort", + style = folderStyle, + ) - Text( - modifier = folderModifier(3), - text = "tokenizer", - style = folderStyle, - ) - Text( - modifier = folderModifier(4), - text = "merges.txt", - style = folderStyle, - ) - Text( - modifier = folderModifier(3), - text = "special_tokens_map.json", - style = folderStyle, - ) - Text( - modifier = folderModifier(4), - text = "tokenizer_config.json", - style = folderStyle, - ) - Text( - modifier = folderModifier(4), - text = "vocab.json", - style = folderStyle, - ) + Text( + modifier = folderModifier(3), + text = "tokenizer", + style = folderStyle, + ) + Text( + modifier = folderModifier(4), + text = "merges.txt", + style = folderStyle, + ) + Text( + modifier = folderModifier(3), + text = "special_tokens_map.json", + style = folderStyle, + ) + Text( + modifier = folderModifier(4), + text = "tokenizer_config.json", + style = folderStyle, + ) + Text( + modifier = folderModifier(4), + text = "vocab.json", + style = folderStyle, + ) - Text( - modifier = folderModifier(3), - text = "unet", - style = folderStyle, - ) - Text( - modifier = folderModifier(4), - text = "model.ort", - style = folderStyle, - ) + Text( + modifier = folderModifier(3), + text = "unet", + style = folderStyle, + ) + Text( + modifier = folderModifier(4), + text = "model.ort", + style = folderStyle, + ) - Text( - modifier = folderModifier(3), - text = "vae_decoder", - style = folderStyle, - ) - Text( - modifier = folderModifier(4), - text = "model.ort", - style = folderStyle, - ) + Text( + modifier = folderModifier(3), + text = "vae_decoder", + style = folderStyle, + ) + Text( + modifier = folderModifier(4), + text = "model.ort", + style = folderStyle, + ) + } } } when (model.downloadState) { @@ -359,10 +365,13 @@ fun LocalDiffusionForm( modifier = Modifier .fillMaxWidth() .padding(top = 14.dp), - value = state.localOnnxCustomModelPath, - onValueChange = { processIntent(ServerSetupIntent.SelectLocalModelPath(it)) }, + value = state.localCustomModelPath, + onValueChange = { string -> + string.filter { it != '\n' } + .let(ServerSetupIntent::SelectLocalModelPath) + .let(processIntent::invoke) + }, enabled = true, - singleLine = true, label = { Text(stringResource(LocalizationR.string.model_local_path_title)) }, trailingIcon = { IconButton( @@ -405,7 +414,8 @@ fun LocalDiffusionForm( } state.localModels .filter { - val customPredicate = it.id == LocalAiModel.CustomOnnx.id || it.id == LocalAiModel.CustomMediaPipe.id + val customPredicate = + it.id == LocalAiModel.CustomOnnx.id || it.id == LocalAiModel.CustomMediaPipe.id if (state.localCustomModel) customPredicate else !customPredicate } .forEach { localModel -> modelItemUi(localModel) } diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageViewModel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageViewModel.kt index fc30f692..52f063ed 100755 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageViewModel.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageViewModel.kt @@ -8,6 +8,7 @@ import com.shifthackz.aisdv1.core.model.asUiText import com.shifthackz.aisdv1.core.notification.PushNotificationManager import com.shifthackz.aisdv1.core.validation.dimension.DimensionValidator import com.shifthackz.aisdv1.domain.entity.HordeProcessStatus +import com.shifthackz.aisdv1.domain.entity.LocalDiffusionStatus import com.shifthackz.aisdv1.domain.entity.ServerSource import com.shifthackz.aisdv1.domain.feature.diffusion.LocalDiffusion import com.shifthackz.aisdv1.domain.feature.work.BackgroundTaskManager @@ -66,11 +67,13 @@ class TextToImageViewModel( ) { private val progressModal: Modal - get() { - if (currentState.mode == ServerSource.LOCAL_MICROSOFT_ONNX) { - return Modal.Generating(canCancel = preferenceManager.localOnnxAllowCancel) + get() = when (currentState.mode) { + ServerSource.LOCAL_MICROSOFT_ONNX, + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> { + Modal.Generating(canCancel = preferenceManager.localOnnxAllowCancel) } - return Modal.Communicating() + + else -> Modal.Communicating() } override val initialState = TextToImageState() @@ -135,7 +138,7 @@ class TextToImageViewModel( ?.let(::setActiveModal) } - override fun onReceivedLocalDiffusionStatus(status: LocalDiffusion.Status) { + override fun onReceivedLocalDiffusionStatus(status: LocalDiffusionStatus) { (currentState.screenModal as? Modal.Generating) ?.copy(status = status) ?.let(::setActiveModal) diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/core/CoreGenerationMviViewModelTest.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/core/CoreGenerationMviViewModelTest.kt index e6c332c2..aa25ae1c 100644 --- a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/core/CoreGenerationMviViewModelTest.kt +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/core/CoreGenerationMviViewModelTest.kt @@ -4,8 +4,8 @@ import com.shifthackz.aisdv1.core.common.schedulers.SchedulersProvider import com.shifthackz.aisdv1.core.notification.PushNotificationManager import com.shifthackz.aisdv1.core.validation.dimension.DimensionValidator import com.shifthackz.aisdv1.domain.entity.HordeProcessStatus +import com.shifthackz.aisdv1.domain.entity.LocalDiffusionStatus import com.shifthackz.aisdv1.domain.entity.Settings -import com.shifthackz.aisdv1.domain.feature.diffusion.LocalDiffusion import com.shifthackz.aisdv1.domain.feature.work.BackgroundTaskManager import com.shifthackz.aisdv1.domain.feature.work.BackgroundWorkObserver import com.shifthackz.aisdv1.domain.interactor.wakelock.WakeLockInterActor @@ -57,7 +57,7 @@ abstract class CoreGenerationMviViewModelTest() private val stubHordeProcessStatus = BehaviorSubject.create() - private val stubLdStatus = BehaviorSubject.create() + private val stubLdStatus = BehaviorSubject.create() protected val stubCustomSchedulers = object : SchedulersProvider { diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModelTest.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModelTest.kt index 57d73860..f6ae1a88 100644 --- a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModelTest.kt +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModelTest.kt @@ -157,6 +157,7 @@ class ServerSetupViewModelTest : CoreViewModelTest() { stubWakeLockInterActor.releaseWakeLockUseCase() } returns Result.success(Unit) + viewModel.processIntent(ServerSetupIntent.UpdateServerMode(ServerSource.LOCAL_MICROSOFT_ONNX)) val localModel = mockServerSetupStateLocalModel.copy( downloadState = DownloadState.Unknown, ) @@ -165,7 +166,7 @@ class ServerSetupViewModelTest : CoreViewModelTest() { val state = viewModel.state.value val expected = true - val actual = state.localOnnxModels.any { + val actual = state.localModels.any { it.downloadState == DownloadState.Downloading(22) } Assert.assertEquals(expected, actual) From 684cefa7ad67be98f1477fa53af5c2e2ca74ee00 Mon Sep 17 00:00:00 2001 From: ShiftHackZ Date: Wed, 28 Aug 2024 09:44:51 +0300 Subject: [PATCH 09/10] Fix source selection list auto non-visible scroll --- .../screen/setup/steps/SourceSelectionStep.kt | 28 +++++++++++++++++-- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/steps/SourceSelectionStep.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/steps/SourceSelectionStep.kt index 54c9f1b0..f7779458 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/steps/SourceSelectionStep.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/steps/SourceSelectionStep.kt @@ -5,16 +5,23 @@ import androidx.compose.foundation.layout.fillMaxWidth import androidx.compose.foundation.layout.height import androidx.compose.foundation.layout.padding import androidx.compose.foundation.lazy.LazyColumn +import androidx.compose.foundation.lazy.LazyListItemInfo import androidx.compose.foundation.lazy.items import androidx.compose.foundation.lazy.rememberLazyListState import androidx.compose.runtime.Composable import androidx.compose.runtime.LaunchedEffect +import androidx.compose.runtime.getValue +import androidx.compose.runtime.mutableIntStateOf +import androidx.compose.runtime.remember +import androidx.compose.runtime.setValue import androidx.compose.ui.Modifier +import androidx.compose.ui.layout.onSizeChanged import androidx.compose.ui.unit.dp import com.shifthackz.aisdv1.domain.entity.ServerSource import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupIntent import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupState import com.shifthackz.aisdv1.presentation.screen.setup.components.ConfigurationModeButton +import kotlin.math.abs @Composable fun SourceSelectionStep( @@ -23,12 +30,26 @@ fun SourceSelectionStep( processIntent: (ServerSetupIntent) -> Unit = {}, ) { val lazyListState = rememberLazyListState() + var lazyListHeight by remember { mutableIntStateOf(0) } + var lazyListItemHeight by remember { mutableIntStateOf(0) } + LaunchedEffect(state.mode) { // Adding 1 here, because item with index == 0 is top spacer - lazyListState.animateScrollToItem(state.mode.ordinal + 1) + val newIndex = state.mode.ordinal +1 + val visibleIndexes = lazyListState.layoutInfo.visibleItemsInfo + .filter { it.offset >= 0 } + .filter { + if (lazyListHeight == 0 || lazyListItemHeight == 0) true + else abs(lazyListHeight - it.offset) >= lazyListItemHeight + } + .map(LazyListItemInfo::index) + + if (!visibleIndexes.contains(newIndex)) lazyListState.animateScrollToItem(newIndex) } + LazyColumn( - modifier = modifier, + modifier = modifier + .onSizeChanged { lazyListHeight = it.height }, state = lazyListState, ) { item(key = "SPACER_TOP") { Spacer(modifier = Modifier.height(12.dp)) } @@ -39,7 +60,8 @@ fun SourceSelectionStep( ConfigurationModeButton( modifier = Modifier .fillMaxWidth() - .padding(horizontal = 16.dp, vertical = 4.dp), + .padding(horizontal = 16.dp, vertical = 4.dp) + .onSizeChanged { lazyListItemHeight = it.height }, state = state, mode = mode, onClick = { From d452e7c8a5c89ddfe7119401902ffd6857d62199 Mon Sep 17 00:00:00 2001 From: ShiftHackZ Date: Sun, 8 Sep 2024 21:29:10 +0300 Subject: [PATCH 10/10] Fix deprecations --- .../aisdv1/presentation/screen/gallery/list/GalleryScreen.kt | 2 +- .../screen/inpaint/components/InPaintComponent.kt | 2 +- .../aisdv1/presentation/widget/input/SliderTextInputField.kt | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/gallery/list/GalleryScreen.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/gallery/list/GalleryScreen.kt index a6052000..67572859 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/gallery/list/GalleryScreen.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/gallery/list/GalleryScreen.kt @@ -468,7 +468,7 @@ fun GalleryScreenContent( val selected = state.selection.contains(item.id) GalleryUiItem( modifier = Modifier - .animateItemPlacement(tween(500)) + .animateItem(tween(500)) .shake( enabled = state.selectionMode && !selected, animationDurationMillis = 188, diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/inpaint/components/InPaintComponent.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/inpaint/components/InPaintComponent.kt index 47e94240..98c45cbd 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/inpaint/components/InPaintComponent.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/inpaint/components/InPaintComponent.kt @@ -133,7 +133,7 @@ fun InPaintComponent( } MotionEvent.Move -> { - currentPath.quadraticBezierTo( + currentPath.quadraticTo( previousPosition.x, previousPosition.y, (previousPosition.x + currentPosition.x) / 2, diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/input/SliderTextInputField.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/input/SliderTextInputField.kt index 4175e70f..8e3964c9 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/input/SliderTextInputField.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/input/SliderTextInputField.kt @@ -94,9 +94,9 @@ fun SliderTextInputField( enabled = true, singleLine = true, keyboardOptions = KeyboardOptions( + autoCorrectEnabled = false, keyboardType = KeyboardType.Number, - autoCorrect = false, - imeAction = ImeAction.Done, + imeAction = ImeAction.Done ), label = { Text(stringResource(id = R.string.hint_value)) }, trailingIcon = {