Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 102 additions & 0 deletions src/commonMain/kotlin/io/github/koalaplot/core/heatmap/ColorScales.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package io.github.koalaplot.core.heatmap

import androidx.compose.ui.graphics.Color
import androidx.compose.ui.graphics.lerp
import io.github.koalaplot.core.util.lerp
import io.github.koalaplot.core.util.normalize

public typealias ColorScale<Z> = (Z) -> Color

/**
* Creates a linear color scale that interpolates between colors.
* @param domain Range of values to map
* @param colors List of colors to interpolate between
*/
public fun <Z> linearColorScale(
domain: ClosedRange<Z>,
colors: List<Color>,
): ColorScale<Z> where Z : Comparable<Z>, Z : Number = { value ->
val normalized = domain.normalize(value).toFloat().coerceIn(0f, 1f)

if (colors.size == 1) {
colors[0]
} else {
val segmentSize = 1f / (colors.size - 1)
val segmentIndex = (normalized / segmentSize).toInt().coerceAtMost(colors.size - 2)
val segmentProgress = (normalized - segmentIndex * segmentSize) / segmentSize

lerp(colors[segmentIndex], colors[segmentIndex + 1], segmentProgress)
}
}

/**
* Creates a diverging color scale with a neutral midpoint.
* @param domain Range of values to map
* @param lowColor Color for low values
* @param midColor Color for midpoint values
* @param highColor Color for high values
*/
@Suppress("MagicNumber")
public fun <Z> divergingColorScale(
domain: ClosedRange<Z>,
lowColor: Color = Color.Blue,
midColor: Color = Color.White,
highColor: Color = Color.Red,
): ColorScale<Z> where Z : Comparable<Z>, Z : Number = { value ->
val normalized = domain.normalize(value).toFloat().coerceIn(0f, 1f)

if (normalized < 0.5f) {
val progress = normalized * 2f
lerp(lowColor, midColor, progress)
} else {
val progress = (normalized - 0.5f) * 2f
lerp(midColor, highColor, progress)
}
}

/**
* Creates a discrete color scale that maps values to specific colors.
* @param thresholds List of threshold values (ascending)
* @param colors List of colors (same length as thresholds + 1)
*/
public fun <Z> discreteColorScale(
thresholds: List<Z>,
colors: List<Color>,
): ColorScale<Z> where Z : Comparable<Z> {
require(colors.size == thresholds.size + 1) {
"There should be one more color (now ${colors.size}) " +
"than thresholds (${thresholds.size})"
}
return { value ->
val index = thresholds.indexOfFirst { it > value }
if (index < 0) {
colors.last()
} else {
colors[index]
}
}
}

/**
* Creates a discrete color scale with automatic binning.
* @param domain Range of values to map
* @param binCount Number of discrete bins
* @param colors List of colors for each bin
*/
public fun <Z> discreteColorScale(
domain: ClosedRange<Z>,
colors: List<Color>,
): ColorScale<Z> where Z : Comparable<Z>, Z : Number {
require(colors.size >= 1) { "Scale needs at least one color" }
val binCount = colors.size

val thresholds = (1 until binCount).map { i ->
val normalized = (0..binCount).normalize(i)
domain.lerp(normalized)
}

return discreteColorScale(
thresholds,
colors,
)
}
107 changes: 107 additions & 0 deletions src/commonMain/kotlin/io/github/koalaplot/core/heatmap/HeatMapPlot.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
package io.github.koalaplot.core.heatmap

import androidx.compose.animation.core.Animatable
import androidx.compose.animation.core.AnimationSpec
import androidx.compose.foundation.Canvas
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.getValue
import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue
import androidx.compose.ui.Modifier
import androidx.compose.ui.geometry.Offset
import androidx.compose.ui.geometry.Size
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.graphics.drawscope.DrawScope
import io.github.koalaplot.core.animation.StartAnimationUseCase
import io.github.koalaplot.core.style.KoalaPlotTheme
import io.github.koalaplot.core.xygraph.Point
import io.github.koalaplot.core.xygraph.XYGraphScope
import kotlin.math.abs
import kotlin.math.max
import kotlin.math.min

public typealias HeatMapGrid<Z> = Array<Array<Z>>

/**
* A HeatMap plot displays 2-dimensional data values as color.
*
* @param xDomain Domain for the x dimension.
* @param yDomain Domain for the y dimension.
* @param bins An 2D array of values.
* @param colorScale A mapping function from value to color.
* @param alphaScale A mapping function from value to alpha.
* @param animationSpec The [AnimationSpec] to use for animating the plot.
*/
@Composable
public fun <X : Comparable<X>, Y : Comparable<Y>, Z> XYGraphScope<X, Y>.HeatMapPlot(
Comment thread
gsteckman marked this conversation as resolved.
xDomain: ClosedRange<X>,
yDomain: ClosedRange<Y>,
bins: HeatMapGrid<Z>,
colorScale: (Z) -> Color,
alphaScale: (Z) -> Float = { 1f },
animationSpec: AnimationSpec<Float> = KoalaPlotTheme.animationSpec,
) {
if (bins.isEmpty() || bins[0].isEmpty()) return

val beta = remember { Animatable(0f) }
LaunchedEffect(null) { beta.animateTo(1f, animationSpec = animationSpec) }

val xBins = bins.size
val yBins = bins[0].size

Canvas(modifier = Modifier.fillMaxSize()) {
fun mapX(x: X): Float = xAxisModel.computeOffset(x) * size.width

fun mapY(y: Y): Float = yAxisModel.computeOffset(y) * size.height

fun <T : Comparable<T>> sortPair(
a: T,
b: T,
): Pair<T, T> = if (a <= b) a to b else b to a

val (left, right) = sortPair(
mapX(xDomain.start),
mapX(xDomain.endInclusive),
)
val (top, bottom) = sortPair(
mapY(yDomain.start),
mapY(yDomain.endInclusive),
)

// Pre-calculate cell size
val cellWidth = (right - left) / xBins
val cellHeight = (top - bottom) / yBins
val cellSize = Size(
beta.value * abs(cellWidth),
beta.value * abs(cellHeight),
)
val animationOffset = (1f - beta.value) / 2f

fun drawRect(
xi: Int,
yi: Int,
) {
val value = bins[xi][yi] ?: return
val alpha = alphaScale(value) * beta.value
if (alpha <= 0f) return
val cellColor = colorScale(value)
val cellLeft = left + (xi + animationOffset) * cellWidth
val cellTop = bottom + (yi + 1 + animationOffset) * cellHeight

drawRect(
color = cellColor,
topLeft = Offset(cellLeft, cellTop),
size = cellSize,
alpha = alpha,
)
}

for (xi in 0 until xBins) {
for (yi in 0 until yBins) {
drawRect(xi, yi)
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package io.github.koalaplot.core.heatmap

import io.github.koalaplot.core.util.lerp
import io.github.koalaplot.core.util.normalize
import kotlin.math.floor

/**
* Generates a 2D histogram from a list of samples.
*
* Creates a 2D grid where each cell contains the count of samples that fall within that cell's boundaries.
* Samples outside the specified domains are ignored (clipped).
*
* @param T The type of data points being processed
* @param X The numeric type for x-coordinates
* @param Y The numeric type for y-coordinates
* @param samples List of data points to histogram
* @param nBinsX Number of bins along x-axis
* @param nBinsY Number of bins along y-axis
* @param xDomain Range of x-values to include in histogram
* @param yDomain Range of y-values to include in histogram
* @param xGetter Function to extract x-coordinate from a sample
* @param yGetter Function to extract y-coordinate from a sample
* @return HeatMapGrid containing the histogram counts
*/
@Suppress("LoopWithTooManyJumpStatements")
public fun <T, X, Y> generateHistogram2D(
samples: List<T>,
xDomain: ClosedRange<X>,
yDomain: ClosedRange<Y>,
xGetter: (T) -> X,
yGetter: (T) -> Y,
nBinsX: Int = 100,
nBinsY: Int = 100,
): HeatMapGrid<Int> where X : Comparable<X>, X : Number, Y : Comparable<Y>, Y : Number {
require(nBinsX > 0 && nBinsY > 0) { "Number of bins must be positive." }

val bins = Array(nBinsX) { Array<Int>(nBinsY) { 0 } }
for (sample in samples) {
val ix = (0..nBinsX).lerp(xDomain.normalize(xGetter(sample)))
val iy = (0..nBinsY).lerp(yDomain.normalize(yGetter(sample)))

if (ix !in 0 until nBinsX) continue
if (iy !in 0 until nBinsY) continue

bins[ix][iy]++
}
return bins
}
69 changes: 69 additions & 0 deletions src/commonMain/kotlin/io/github/koalaplot/core/util/Range.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package io.github.koalaplot.core.util

@Suppress("UNCHECKED_CAST")
private fun <Z : Number> doubleToTypeOf(
value: Double,
example: Z,
): Z = when (example) {
is Double -> {
value as Z
}

is Float -> {
value.toFloat() as Z
}

is Int -> {
kotlin.math.floor(value).toInt() as Z
}

is Long -> {
kotlin.math.floor(value).toLong() as Z
}

is Short -> {
kotlin.math
.floor(value)
.toInt()
.toShort() as Z
}

is Byte -> {
kotlin.math
.floor(value)
.toInt()
.toByte() as Z
}

else -> {
throw UnsupportedOperationException("Unsupported numeric type: ${example::class}")
}
}

/**
* Linearly normalizes the value within the range between 0.0 and 1.0.
* Values outside the range are extrapolated.
* When extremes of the range are equal, always returns zero.
* This is the inverse of operation [lerp].
*/
public fun <T> ClosedRange<T>.normalize(value: T): Double
where T : Number, T : Comparable<T> {
val r0 = start.toDouble()
val r1 = endInclusive.toDouble()
val size = r1 - r0
if (size == 0.0) return 0.0
return (value.toDouble() - r0) / size
}

/**
* Linearly interpolates within the range by the factor t.
* For t values beyond 0.0..1.0, linear extrapolation is done.
* This is the inverse of operation [normalize].
*/
public fun <T> ClosedRange<T>.lerp(t: Double): T
where T : Number, T : Comparable<T> {
val r0 = start.toDouble()
val r1 = endInclusive.toDouble()
val size = r1 - r0
return doubleToTypeOf(t * size + r0, start)
}
Loading