Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions app/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,7 @@ dependencies {
implementation(libs.phrase)
implementation(libs.copper.flow)
implementation(libs.kotlinx.coroutines.android)
implementation(libs.kotlinx.coroutines.guava)
implementation(libs.kovenant)
implementation(libs.kovenant.android)
implementation(libs.opencsv)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -288,10 +288,6 @@ class RemoteFileDownloadWorker @AssistedInject constructor(
return File(downloadsDirectory(context), remote.sha256Hash())
}

fun cancelAll(context: Context) {
WorkManager.getInstance(context).cancelAllWorkByTag(TAG)
}

private fun uniqueWorkName(remote: RemoteFile): String {
return "download-remote-file-${remote.sha256Hash()}"
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package org.thoughtcrime.securesms.notifications

import android.content.Context
import androidx.work.WorkInfo
import androidx.work.WorkManager
import dagger.hilt.android.qualifiers.ApplicationContext
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.CoroutineScope
Expand All @@ -10,8 +12,9 @@ import kotlinx.coroutines.Job
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.combine
import kotlinx.coroutines.flow.debounce
import kotlinx.coroutines.flow.distinctUntilChanged
import kotlinx.coroutines.flow.onStart
import kotlinx.coroutines.flow.scan
import kotlinx.coroutines.guava.await
import kotlinx.coroutines.launch
import kotlinx.coroutines.supervisorScope
import org.session.libsession.database.userAuth
Expand All @@ -25,28 +28,28 @@ import org.thoughtcrime.securesms.database.Storage
import org.thoughtcrime.securesms.dependencies.ConfigFactory
import org.thoughtcrime.securesms.dependencies.ManagerScope
import org.thoughtcrime.securesms.dependencies.OnAppStartupComponent
import java.security.MessageDigest
import javax.inject.Inject
import javax.inject.Singleton

private const val TAG = "PushRegistrationHandler"

/**
* A class that listens to the config, user's preference, token changes and
* register/unregister push notification accordingly.
* PN registration source of truth using per-account periodic workers.
*
* Periodic workers must be created with tags:
* - "pn-register-periodic"
* - "pn-acc-<hexAccountId>"
* - "pn-tfp-<tokenFingerprint>"
*
* This class DOES NOT handle the legacy groups push notification.
*/
@Singleton
class PushRegistrationHandler
@Inject
constructor(
class PushRegistrationHandler @Inject constructor(
private val configFactory: ConfigFactory,
private val preferences: TextSecurePreferences,
private val tokenFetcher: TokenFetcher,
@param:ApplicationContext private val context: Context,
@ApplicationContext private val context: Context,
private val registry: PushRegistryV2,
private val storage: Storage,
@param:ManagerScope private val scope: CoroutineScope
@ManagerScope private val scope: CoroutineScope
) : OnAppStartupComponent {

private var job: Job? = null
Expand All @@ -62,83 +65,171 @@ constructor(
.onStart { emit(Unit) },
preferences.watchLocalNumber(),
preferences.pushEnabled,
tokenFetcher.token,
) { _, myAccountId, enabled, token ->
if (!enabled || myAccountId == null || storage.getUserED25519KeyPair() == null || token.isNullOrEmpty()) {
return@combine emptySet<SubscriptionKey>()
tokenFetcher.token
) { _, _, enabled, token ->
val desired =
if (enabled && hasCoreIdentity())
desiredSubscriptions()
else emptySet()
Triple(enabled, token, desired)
}
.distinctUntilChanged()
.collect { (pushEnabled, token, desiredIds) ->
try {
reconcileWithWorkManager(pushEnabled, token, desiredIds)
} catch (t: Throwable) {
Log.e(TAG, "Reconciliation failed", t)
}
}
}
}

private suspend fun reconcileWithWorkManager(
pushEnabled: Boolean,
token: String?,
activeAccounts: Set<AccountId>
) {
val wm = WorkManager.getInstance(context)

// Read existing push periodic workers and parse (AccountId, tokenFingerprint) from tags.
val periodicInfos = wm.getWorkInfosByTag(TAG_PERIODIC).await()
.filter { it.state != WorkInfo.State.CANCELLED && it.state != WorkInfo.State.FAILED }

setOf(SubscriptionKey(AccountId(myAccountId), token)) + getGroupSubscriptions(token)
Log.d(TAG, "We currently have ${periodicInfos.size} push periodic workers")

val accountsAlreadyRegistered: Map<AccountId, String> = buildMap {
for (info in periodicInfos) {
val id = parseAccountId(info) ?: continue
val token = parseTokenFingerprint(info) ?: continue
put(id, token)
}
.scan(emptySet<SubscriptionKey>() to emptySet<SubscriptionKey>()) { acc, current ->
acc.second to current
}
.collect { (prev, current) ->
val added = current - prev
val removed = prev - current
if (added.isNotEmpty()) {
Log.d(TAG, "Adding ${added.size} new subscriptions")
}
}

if (removed.isNotEmpty()) {
Log.d(TAG, "Removing ${removed.size} subscriptions")
// If push disabled or identity missing → cancel all and try to deregister.
if (!pushEnabled || !hasCoreIdentity()) {
val toCancel = accountsAlreadyRegistered.keys
if (toCancel.isNotEmpty()) {
Log.d(TAG, "Push disabled/identity missing; cancelling ${toCancel.size} PN periodic works")
}
supervisorScope {
toCancel.forEach { id ->
launch {
PushRegistrationWorker.cancelAll(context, id)
tryUnregister(token, id)
}
}
}
return
}

for (key in added) {
PushRegistrationWorker.schedule(
context = context,
token = key.token,
accountId = key.accountId,
)
}
val currentFingerprint = token?.let { tokenFingerprint(it) }

supervisorScope {
for (key in removed) {
PushRegistrationWorker.cancelRegistration(
context = context,
accountId = key.accountId,
)

launch {
Log.d(TAG, "Unregistering push token for account: ${key.accountId}")
try {
val swarmAuth = swarmAuthForAccount(key.accountId)
?: throw IllegalStateException("No SwarmAuth found for account: ${key.accountId}")

registry.unregister(
token = key.token,
swarmAuth = swarmAuth,
)

Log.d(TAG, "Successfully unregistered push token for account: ${key.accountId}")
} catch (e: Exception) {
if (e !is CancellationException) {
Log.e(TAG, "Failed to unregister push token for account: ${key.accountId}", e)
}
}
}
}
}
// Add missing (ensure periodic + run now) — only if we have a token.
val accountsToAdd = activeAccounts - accountsAlreadyRegistered.keys
if (accountsToAdd.isNotEmpty()) Log.d(TAG, "Adding ${accountsToAdd.size} PN registrations")
if (!token.isNullOrEmpty()) {
accountsToAdd.forEach { id ->
PushRegistrationWorker.ensurePeriodic(context, id, token, replace = false) // KEEP
PushRegistrationWorker.scheduleImmediate(context, id, token) // run now
}
}

// Token rotation: replace periodic where fingerprint mismatches.
if (!token.isNullOrEmpty()) {
var replaced = 0
activeAccounts.forEach { id ->
val tokenFingerprint = accountsAlreadyRegistered[id] ?: return@forEach
if (tokenFingerprint != currentFingerprint) {
PushRegistrationWorker.ensurePeriodic(context, id, token, replace = true) // REPLACE
PushRegistrationWorker.scheduleImmediate(context, id, token)
replaced++
}
}
if (replaced > 0) Log.d(TAG, "Replaced $replaced periodic PN workers due to token rotation")
}

// Removed subscriptions: cancel workers & attempt deregister.
val accountToRemove = accountsAlreadyRegistered.keys - activeAccounts
if (accountToRemove.isNotEmpty()) Log.d(TAG, "Removing ${accountToRemove.size} PN registrations")
supervisorScope {
accountToRemove.forEach { id ->
launch {
PushRegistrationWorker.cancelAll(context, id)
tryUnregister(token, id)
}
}
}
}

/**
* Build desired subscriptions: self (local number) + any group that shouldPoll.
* */
private fun desiredSubscriptions(): Set<AccountId> = buildSet {
preferences.getLocalNumber()?.let { add(AccountId(it)) }
val groups = configFactory.withUserConfigs { it.userGroups.allClosedGroupInfo() }
groups.filter { it.shouldPoll }
.mapTo(this) { AccountId(it.groupAccountId) }
}

private fun hasCoreIdentity(): Boolean {
return preferences.getLocalNumber() != null && storage.getUserED25519KeyPair() != null
}

/**
* Try to deregister if we still have credentials and a token to sign with.
* Safe to no-op if token/auth missing (e.g., keys already deleted).
*/
private suspend fun tryUnregister(token: String?, accountId: AccountId) {
if (token.isNullOrEmpty()) return
val auth = swarmAuthForAccount(accountId) ?: return
try {
Log.d(TAG, "Unregistering PN for $accountId")
registry.unregister(token = token, swarmAuth = auth)
Log.d(TAG, "Unregistered PN for $accountId")
} catch (e: Exception) {
if (e !is CancellationException) {
Log.e(TAG, "Unregister failed for $accountId", e)
} else {
throw e
}
}
}

private fun swarmAuthForAccount(accountId: AccountId): SwarmAuth? {
return when (accountId.prefix) {
IdPrefix.STANDARD -> storage.userAuth?.takeIf { it.accountId == accountId }
IdPrefix.GROUP -> configFactory.getGroupAuth(accountId)
else -> null // Unsupported account ID prefix
IdPrefix.GROUP -> configFactory.getGroupAuth(accountId)
else -> null
}
}

private fun getGroupSubscriptions(
token: String
): Set<SubscriptionKey> {
return configFactory.withUserConfigs { it.userGroups.allClosedGroupInfo() }
.asSequence()
.filter { it.shouldPoll }
.mapTo(hashSetOf()) { SubscriptionKey(accountId = AccountId(it.groupAccountId), token = token) }
private fun parseAccountId(info: WorkInfo): AccountId? {
val tag = info.tags.firstOrNull { it.startsWith(ARG_ACCOUNT_ID) } ?: return null
val hex = tag.removePrefix(ARG_ACCOUNT_ID)
return AccountId.fromStringOrNull(hex)
}

private data class SubscriptionKey(val accountId: AccountId, val token: String)
}
private fun parseTokenFingerprint(info: WorkInfo): String? {
val tag = info.tags.firstOrNull { it.startsWith(ARG_TOKEN) } ?: return null
return tag.removePrefix(ARG_TOKEN)
}

companion object {
private const val TAG = "PushRegistrationHandler"

const val TAG_PERIODIC = "pn-register-periodic"
const val ARG_ACCOUNT_ID = "pn-account-"
const val ARG_TOKEN = "pn-token-"

fun tokenFingerprint(token: String): String {
val digest = MessageDigest.getInstance("SHA-256")
.digest(token.toByteArray(Charsets.UTF_8))
val short = digest.copyOfRange(0, 8) // 64 bits is plenty for equality checks
@Suppress("InlinedApi")
return android.util.Base64.encodeToString(
short,
android.util.Base64.NO_WRAP or android.util.Base64.URL_SAFE
)
}
}
}
Loading