Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ dataframe_benchmark/
bin/
coverage-html
.DS_Store
flake.lock
tags
__pycache__
venv
Expand All @@ -45,4 +44,4 @@ Cargo.lock
# (transient; the committed *.db fixtures themselves stay tracked).
*.db-wal
*.db-shm
*.db-journal
*.db-journal
111 changes: 82 additions & 29 deletions dataframe-learn/src/DataFrame/DecisionTree/Cart.hs

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All dataframe-learn changes are from fourmolu

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd suggest just removing those changes from the PR

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lint doesn't pass without them! I think "passes its own lint steps" is an element of project config consistency. They're in a separate commit, though, so they're easy to split out 🤷🏻‍♂️

Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}

-- | sklearn-faithful CART initializer used to seed TAO. One-hot encodes
-- categoricals, splits on exact (unsmoothed) Gini over midpoint thresholds
-- (@<=@ routes left), and emits a @Tree@ predicting identically to
-- @DecisionTreeClassifier(criterion='gini')@ on continuous features.
{- | sklearn-faithful CART initializer used to seed TAO. One-hot encodes
categoricals, splits on exact (unsmoothed) Gini over midpoint thresholds
(@<=@ routes left), and emits a @Tree@ predicting identically to
@DecisionTreeClassifier(criterion='gini')@ on continuous features.
-}
module DataFrame.DecisionTree.Cart (
CartFeature (..),
CartNode (..),
Expand Down Expand Up @@ -39,8 +40,9 @@ import qualified Data.Vector.Algorithms.Merge as VA
import qualified Data.Vector.Unboxed as VU
import Type.Reflection (typeRep)

-- | A one-hot feature column: per-row Double values plus the sklearn LEFT
-- predicate (@x <= threshold@) over the ORIGINAL DataFrame.
{- | A one-hot feature column: per-row Double values plus the sklearn LEFT
predicate (@x <= threshold@) over the ORIGINAL DataFrame.
-}
data CartFeature = CartFeature
{ cfValues :: !(VU.Vector Double)
, cfPred :: !(Double -> Expr Bool)
Expand All @@ -59,16 +61,18 @@ data CartCtx = CartCtx
, ctxMinLeaf :: !Int
}

-- | Indices @0..n-1@ stably sorted by their value (ascending), ties keeping
-- ascending index. In-place unboxed merge sort — no boxed-list allocation.
{- | Indices @0..n-1@ stably sorted by their value (ascending), ties keeping
ascending index. In-place unboxed merge sort — no boxed-list allocation.
-}
sortIndicesByValue :: VU.Vector Double -> VU.Vector Int
sortIndicesByValue vs = VU.modify (VA.sortBy (compare `on` valueAt)) identityPerm
  where
    -- Sort key: the value stored at a given row index.
    valueAt = (vs VU.!)
    -- The unsorted permutation 0..n-1, one entry per element of vs.
    identityPerm = VU.enumFromN 0 (VU.length vs)

buildCartTree :: forall a. (Columnable a, Ord a) => TreeConfig -> T.Text -> DataFrame -> Tree a
buildCartTree ::
forall a. (Columnable a, Ord a) => TreeConfig -> T.Text -> DataFrame -> Tree a
buildCartTree cfg target df =
cartToTree feats classes (buildCartNode ctx 0 (VU.enumFromN 0 nAll) featSorted)
where
Expand Down Expand Up @@ -109,21 +113,35 @@ cartToTree feats classes = go

-- | Per-class row counts for the rows listed in @idxs@: starts from a zero
-- vector of length @ctxNClasses ctx@ and adds 1 at position
-- @ctxCodes ctx VU.! i@ for every row index @i@ in @idxs@.
classCounts :: CartCtx -> VU.Vector Int -> VU.Vector Int
classCounts ctx idxs =
VU.accumulate (+) (VU.replicate (ctxNClasses ctx) 0) (VU.map (\i -> (ctxCodes ctx VU.! i, 1)) idxs)
VU.accumulate
(+)
(VU.replicate (ctxNClasses ctx) 0)
(VU.map (\i -> (ctxCodes ctx VU.! i, 1)) idxs)

-- | 'True' when at most one entry of @counts@ is strictly positive.
isPure :: VU.Vector Int -> Bool
isPure counts = positives <= 1
  where
    -- Number of strictly positive entries in counts.
    positives :: Int
    positives = VU.foldl' (\k c -> if c > 0 then k + 1 else k) 0 counts

-- | Build the subtree for the rows in @idxs@ at the given @depth@.
--
-- Produces a leaf when fewer than two rows remain, when @depth@ has reached
-- @ctxMaxDepth ctx@, when the class counts are pure ('isPure'), or when
-- 'bestSplit' returns 'Nothing'. Otherwise the chosen @(feature, threshold)@
-- is handed to 'splitNode'.
buildCartNode :: CartCtx -> Int -> VU.Vector Int -> V.Vector (VU.Vector Int) -> CartNode
buildCartNode ::
CartCtx -> Int -> VU.Vector Int -> V.Vector (VU.Vector Int) -> CartNode
buildCartNode ctx depth idxs sortedByFeat
| VU.length idxs < 2 || depth >= ctxMaxDepth ctx || isPure counts = leaf
| otherwise = maybe leaf (splitNode ctx depth idxs sortedByFeat) (bestSplit ctx sortedByFeat counts n)
| otherwise =
maybe
leaf
(splitNode ctx depth idxs sortedByFeat)
(bestSplit ctx sortedByFeat counts n)
where
-- Number of rows at this node.
n = VU.length idxs
-- Per-class counts of the rows at this node.
counts = classCounts ctx idxs
-- Leaf labelled with the index of the largest class count.
leaf = CLeaf (VU.maxIndex counts)

splitNode :: CartCtx -> Int -> VU.Vector Int -> V.Vector (VU.Vector Int) -> (Int, Double) -> CartNode
splitNode ::
CartCtx ->
Int ->
VU.Vector Int ->
V.Vector (VU.Vector Int) ->
(Int, Double) ->
CartNode
splitNode ctx depth idxs sortedByFeat (fj, thr) =
CSplit fj thr (rec leftIdx leftSorted) (rec rightIdx rightSorted)
where
Expand All @@ -134,9 +152,15 @@ splitNode ctx depth idxs sortedByFeat (fj, thr) =
rightSorted = V.map (VU.filter (\i -> vals VU.! i > thr)) sortedByFeat
rec = buildCartNode ctx (depth + 1)

-- | Minimum weighted-child-Gini @(feature, threshold)@; the first feature wins
-- ties; 'Nothing' when no feature has a leaf-size-respecting threshold.
bestSplit :: CartCtx -> V.Vector (VU.Vector Int) -> VU.Vector Int -> Int -> Maybe (Int, Double)
{- | Minimum weighted-child-Gini @(feature, threshold)@; the first feature wins
ties; 'Nothing' when no feature has a leaf-size-respecting threshold.
-}
bestSplit ::
CartCtx ->
V.Vector (VU.Vector Int) ->
VU.Vector Int ->
Int ->
Maybe (Int, Double)
bestSplit ctx sortedByFeat counts n =
fmap (\(_, j, t) -> (j, t)) (foldl' consider Nothing [0 .. ctxNFeats ctx - 1])
where
Expand All @@ -145,18 +169,30 @@ bestSplit ctx sortedByFeat counts n =
Just (g, thr) | maybe True (\(gB, _, _) -> g < gB) acc -> Just (g, fj, thr)
_ -> acc

-- | Accumulator while sweeping a feature's sorted rows: best @(gini, thr)@ so
-- far, per-class left counts, rows moved left, and the previous value seen.
{- | Accumulator while sweeping a feature's sorted rows: best @(gini, thr)@ so
far, per-class left counts, rows moved left, and the previous value seen.
-}
data Sweep = Sweep
{ swBest :: !(Maybe (Double, Double))
-- ^ Best @(gini, threshold)@ found so far; 'Nothing' until a candidate is accepted.
, swLeft :: ![Int]
-- ^ Per-class counts of the rows already moved to the left side.
, swMoved :: !Int
-- ^ Number of rows already moved to the left side.
, swPrev :: !Double
-- ^ Feature value of the previously visited row (the sweep is seeded with NaN, @0 / 0@).
}

sweepFeature :: CartCtx -> [Int] -> VU.Vector Int -> CartFeature -> Int -> Maybe (Double, Double)
sweepFeature ::
CartCtx ->
[Int] ->
VU.Vector Int ->
CartFeature ->
Int ->
Maybe (Double, Double)
sweepFeature ctx total si feat n =
swBest (foldl' step (Sweep Nothing (replicate (ctxNClasses ctx) 0) 0 (0 / 0)) [0 .. VU.length si - 1])
swBest
( foldl'
step
(Sweep Nothing (replicate (ctxNClasses ctx) 0) 0 (0 / 0))
[0 .. VU.length si - 1]
)
where
vals = cfValues feat
step s k = advance ctx total n (vals VU.! i) (ctxCodes ctx VU.! i) s
Expand All @@ -165,24 +201,35 @@ sweepFeature ctx total si feat n =

-- | Fold one sorted row (feature value @v@, class code @c@) into the sweep
-- state. The threshold candidate is evaluated by 'considerThreshold' against
-- the state /before/ this row is moved; then the left counts are passed
-- through @bumpClass c@ (defined elsewhere; presumably increments class @c@),
-- the moved-row count grows by one, and @v@ becomes the previous value.
advance :: CartCtx -> [Int] -> Int -> Double -> Int -> Sweep -> Sweep
advance ctx total n v c s =
Sweep (considerThreshold ctx total n v s) (bumpClass c (swLeft s)) (swMoved s + 1) v
Sweep
(considerThreshold ctx total n v s)
(bumpClass c (swLeft s))
(swMoved s + 1)
v

-- | Possibly record a threshold between the previous value and @v@.
--
-- A candidate is scored only when both sides would hold at least
-- @ctxMinLeaf ctx@ rows (@swMoved s@ on the left, @n - swMoved s@ on the
-- right) and @v@ exceeds the previous value by more than @1e-7@. The candidate
-- threshold is the midpoint @(swPrev s + v) / 2@, scored with 'weightedGini'
-- and merged into the running best by 'keepBetter'. Otherwise the running best
-- is returned unchanged. A NaN @swPrev@ fails the @>@ comparison, so no
-- threshold is considered while the previous value is NaN.
considerThreshold :: CartCtx -> [Int] -> Int -> Double -> Sweep -> Maybe (Double, Double)
considerThreshold ::
CartCtx -> [Int] -> Int -> Double -> Sweep -> Maybe (Double, Double)
considerThreshold ctx total n v s
| swMoved s >= ctxMinLeaf ctx
, n - swMoved s >= ctxMinLeaf ctx
, v > swPrev s + 1e-7 =
keepBetter (swBest s) (weightedGini total (swLeft s) (swMoved s) n) ((swPrev s + v) / 2)
keepBetter
(swBest s)
(weightedGini total (swLeft s) (swMoved s) n)
((swPrev s + v) / 2)
| otherwise = swBest s

-- | Merge a candidate @(g, thr)@ into the running best. The existing best is
-- kept when its score is less than or equal to @g@, so ties favour the earlier
-- candidate; in every other case (including no best yet) the candidate wins.
keepBetter :: Maybe (Double, Double) -> Double -> Double -> Maybe (Double, Double)
keepBetter ::
Maybe (Double, Double) -> Double -> Double -> Maybe (Double, Double)
keepBetter best g thr = case best of
Just (wb, _) | wb <= g -> best
_ -> Just (g, thr)

weightedGini :: [Int] -> [Int] -> Int -> Int -> Double
weightedGini total leftAcc nl n =
(fromIntegral nl * giniImpurity leftAcc nl + fromIntegral nr * giniImpurity rightAcc nr)
( fromIntegral nl * giniImpurity leftAcc nl
+ fromIntegral nr * giniImpurity rightAcc nr
)
/ fromIntegral n
where
nr = n - nl
Expand All @@ -205,21 +252,27 @@ featuresOfColumn df c = case unsafeGetColumn c df of
UnboxedColumn _ (v :: VU.Vector b) -> numericFeature @b c v
BoxedColumn _ (v :: V.Vector b) -> oneHotFeatures @b (nRows df) c v

-- | Features derived from an unboxed column named @c@.
--
-- * A @Double@ column yields one feature over the raw values, with left
--   predicate @col c <= t@.
-- * A column whose element type 'sIntegral' reports as 'STrue' yields one
--   feature whose values are converted with 'fromIntegral', with left
--   predicate @toDouble (col c) <= t@.
-- * Any other element type yields no features.
numericFeature :: forall b. (Columnable b, VU.Unbox b) => T.Text -> VU.Vector b -> [CartFeature]
numericFeature ::
forall b. (Columnable b, VU.Unbox b) => T.Text -> VU.Vector b -> [CartFeature]
numericFeature c v = case testEquality (typeRep @b) (typeRep @Double) of
Just Refl -> [CartFeature v (\t -> F.col @Double c .<=. F.lit t)]
Nothing -> case sIntegral @b of
STrue -> [CartFeature (VU.map fromIntegral v) (\t -> F.toDouble (F.col @b c) .<=. F.lit t)]
STrue ->
[ CartFeature (VU.map fromIntegral v) (\t -> F.toDouble (F.col @b c) .<=. F.lit t)
]
SFalse -> []

-- | One-hot features for a boxed column named @c@. When the element type is
-- 'T.Text', one 'oneHot' feature is produced per distinct value, in the
-- ascending order given by 'Set.toList'. Any other element type yields no
-- features.
oneHotFeatures :: forall b. (Columnable b) => Int -> T.Text -> V.Vector b -> [CartFeature]
oneHotFeatures ::
forall b. (Columnable b) => Int -> T.Text -> V.Vector b -> [CartFeature]
oneHotFeatures nAll c v = case testEquality (typeRep @b) (typeRep @T.Text) of
Just Refl -> [oneHot nAll c v cat | cat <- Set.toList (Set.fromList (V.toList v))]
Nothing -> []

-- | Indicator feature for category @cat@ of column @c@: the value is 1 for
-- rows where @v@ equals @cat@ and 0 elsewhere, generated for @nAll@ rows
-- (assumes @nAll@ does not exceed the length of @v@ — TODO confirm at the
-- caller). The left predicate ignores the threshold and is @col c /= cat@;
-- presumably this is because non-matching rows carry 0 and so fall on the
-- @<=@ side of any threshold between 0 and 1.
oneHot :: Int -> T.Text -> V.Vector T.Text -> T.Text -> CartFeature
oneHot nAll c v cat =
CartFeature (VU.generate nAll (\i -> if v V.! i == cat then 1 else 0)) (const (F.col @T.Text c ./=. F.lit cat))
CartFeature
(VU.generate nAll (\i -> if v V.! i == cat then 1 else 0))
(const (F.col @T.Text c ./=. F.lit cat))

-- | Target column as string labels (matches pandas @y.astype(str)@).
cartTargetLabels :: T.Text -> DataFrame -> V.Vector T.Text
Expand Down
Loading
Loading