diff --git a/.gitignore b/.gitignore index 607f2e4..ada5d0b 100644 --- a/.gitignore +++ b/.gitignore @@ -28,7 +28,6 @@ dataframe_benchmark/ bin/ coverage-html .DS_Store -flake.lock tags __pycache__ venv @@ -45,4 +44,4 @@ Cargo.lock # (transient; the committed *.db fixtures themselves stay tracked). *.db-wal *.db-shm -*.db-journal \ No newline at end of file +*.db-journal diff --git a/dataframe-learn/src/DataFrame/DecisionTree/Cart.hs b/dataframe-learn/src/DataFrame/DecisionTree/Cart.hs index fb4188a..5337f5c 100644 --- a/dataframe-learn/src/DataFrame/DecisionTree/Cart.hs +++ b/dataframe-learn/src/DataFrame/DecisionTree/Cart.hs @@ -4,10 +4,11 @@ {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} --- | sklearn-faithful CART initializer used to seed TAO. One-hot encodes --- categoricals, splits on exact (unsmoothed) Gini over midpoint thresholds --- (@<=@ routes left), and emits a @Tree@ predicting identically to --- @DecisionTreeClassifier(criterion='gini')@ on continuous features. +{- | sklearn-faithful CART initializer used to seed TAO. One-hot encodes +categoricals, splits on exact (unsmoothed) Gini over midpoint thresholds +(@<=@ routes left), and emits a @Tree@ predicting identically to +@DecisionTreeClassifier(criterion='gini')@ on continuous features. +-} module DataFrame.DecisionTree.Cart ( CartFeature (..), CartNode (..), @@ -39,8 +40,9 @@ import qualified Data.Vector.Algorithms.Merge as VA import qualified Data.Vector.Unboxed as VU import Type.Reflection (typeRep) --- | A one-hot feature column: per-row Double values plus the sklearn LEFT --- predicate (@x <= threshold@) over the ORIGINAL DataFrame. +{- | A one-hot feature column: per-row Double values plus the sklearn LEFT +predicate (@x <= threshold@) over the ORIGINAL DataFrame. +-} data CartFeature = CartFeature { cfValues :: !(VU.Vector Double) , cfPred :: !(Double -> Expr Bool) @@ -59,8 +61,9 @@ data CartCtx = CartCtx , ctxMinLeaf :: !Int } --- | Indices @0..n-1@ stably sorted by their value (ascending), ties keeping --- ascending index. In-place unboxed merge sort — no boxed-list allocation. +{- | Indices @0..n-1@ stably sorted by their value (ascending), ties keeping +ascending index. In-place unboxed merge sort — no boxed-list allocation. +-} sortIndicesByValue :: VU.Vector Double -> VU.Vector Int sortIndicesByValue vs = VU.create $ do @@ -68,7 +71,8 @@ sortIndicesByValue vs = VA.sortBy (compare `on` (vs VU.!)) mv pure mv -buildCartTree :: forall a. (Columnable a, Ord a) => TreeConfig -> T.Text -> DataFrame -> Tree a +buildCartTree :: + forall a. (Columnable a, Ord a) => TreeConfig -> T.Text -> DataFrame -> Tree a buildCartTree cfg target df = cartToTree feats classes (buildCartNode ctx 0 (VU.enumFromN 0 nAll) featSorted) where @@ -109,21 +113,35 @@ cartToTree feats classes = go classCounts :: CartCtx -> VU.Vector Int -> VU.Vector Int classCounts ctx idxs = - VU.accumulate (+) (VU.replicate (ctxNClasses ctx) 0) (VU.map (\i -> (ctxCodes ctx VU.! i, 1)) idxs) + VU.accumulate + (+) + (VU.replicate (ctxNClasses ctx) 0) + (VU.map (\i -> (ctxCodes ctx VU.! i, 1)) idxs) isPure :: VU.Vector Int -> Bool isPure counts = VU.length (VU.filter (> 0) counts) <= 1 -buildCartNode :: CartCtx -> Int -> VU.Vector Int -> V.Vector (VU.Vector Int) -> CartNode +buildCartNode :: + CartCtx -> Int -> VU.Vector Int -> V.Vector (VU.Vector Int) -> CartNode buildCartNode ctx depth idxs sortedByFeat | VU.length idxs < 2 || depth >= ctxMaxDepth ctx || isPure counts = leaf - | otherwise = maybe leaf (splitNode ctx depth idxs sortedByFeat) (bestSplit ctx sortedByFeat counts n) + | otherwise = + maybe + leaf + (splitNode ctx depth idxs sortedByFeat) + (bestSplit ctx sortedByFeat counts n) where n = VU.length idxs counts = classCounts ctx idxs leaf = CLeaf (VU.maxIndex counts) -splitNode :: CartCtx -> Int -> VU.Vector Int -> V.Vector (VU.Vector Int) -> (Int, Double) -> CartNode +splitNode :: + CartCtx -> + Int -> + VU.Vector Int -> + V.Vector (VU.Vector Int) -> + (Int, Double) -> + CartNode splitNode ctx depth idxs sortedByFeat (fj, thr) = CSplit fj thr (rec leftIdx leftSorted) (rec rightIdx rightSorted) where @@ -134,9 +152,15 @@ splitNode ctx depth idxs sortedByFeat (fj, thr) = rightSorted = V.map (VU.filter (\i -> vals VU.! i > thr)) sortedByFeat rec = buildCartNode ctx (depth + 1) --- | Minimum weighted-child-Gini @(feature, threshold)@; the first feature wins --- ties; 'Nothing' when no feature has a leaf-size-respecting threshold. -bestSplit :: CartCtx -> V.Vector (VU.Vector Int) -> VU.Vector Int -> Int -> Maybe (Int, Double) +{- | Minimum weighted-child-Gini @(feature, threshold)@; the first feature wins +ties; 'Nothing' when no feature has a leaf-size-respecting threshold. +-} +bestSplit :: + CartCtx -> + V.Vector (VU.Vector Int) -> + VU.Vector Int -> + Int -> + Maybe (Int, Double) bestSplit ctx sortedByFeat counts n = fmap (\(_, j, t) -> (j, t)) (foldl' consider Nothing [0 .. ctxNFeats ctx - 1]) where @@ -145,8 +169,9 @@ bestSplit ctx sortedByFeat counts n = Just (g, thr) | maybe True (\(gB, _, _) -> g < gB) acc -> Just (g, fj, thr) _ -> acc --- | Accumulator while sweeping a feature's sorted rows: best @(gini, thr)@ so --- far, per-class left counts, rows moved left, and the previous value seen. +{- | Accumulator while sweeping a feature's sorted rows: best @(gini, thr)@ so +far, per-class left counts, rows moved left, and the previous value seen. +-} data Sweep = Sweep { swBest :: !(Maybe (Double, Double)) , swLeft :: ![Int] @@ -154,9 +179,20 @@ data Sweep = Sweep , swPrev :: !Double } -sweepFeature :: CartCtx -> [Int] -> VU.Vector Int -> CartFeature -> Int -> Maybe (Double, Double) +sweepFeature :: + CartCtx -> + [Int] -> + VU.Vector Int -> + CartFeature -> + Int -> + Maybe (Double, Double) sweepFeature ctx total si feat n = - swBest (foldl' step (Sweep Nothing (replicate (ctxNClasses ctx) 0) 0 (0 / 0)) [0 .. VU.length si - 1]) + swBest + ( foldl' + step + (Sweep Nothing (replicate (ctxNClasses ctx) 0) 0 (0 / 0)) + [0 .. VU.length si - 1] + ) where vals = cfValues feat step s k = advance ctx total n (vals VU.! i) (ctxCodes ctx VU.! i) s @@ -165,24 +201,35 @@ sweepFeature ctx total si feat n = advance :: CartCtx -> [Int] -> Int -> Double -> Int -> Sweep -> Sweep advance ctx total n v c s = - Sweep (considerThreshold ctx total n v s) (bumpClass c (swLeft s)) (swMoved s + 1) v + Sweep + (considerThreshold ctx total n v s) + (bumpClass c (swLeft s)) + (swMoved s + 1) + v -considerThreshold :: CartCtx -> [Int] -> Int -> Double -> Sweep -> Maybe (Double, Double) +considerThreshold :: + CartCtx -> [Int] -> Int -> Double -> Sweep -> Maybe (Double, Double) considerThreshold ctx total n v s | swMoved s >= ctxMinLeaf ctx , n - swMoved s >= ctxMinLeaf ctx , v > swPrev s + 1e-7 = - keepBetter (swBest s) (weightedGini total (swLeft s) (swMoved s) n) ((swPrev s + v) / 2) + keepBetter + (swBest s) + (weightedGini total (swLeft s) (swMoved s) n) + ((swPrev s + v) / 2) | otherwise = swBest s -keepBetter :: Maybe (Double, Double) -> Double -> Double -> Maybe (Double, Double) +keepBetter :: + Maybe (Double, Double) -> Double -> Double -> Maybe (Double, Double) keepBetter best g thr = case best of Just (wb, _) | wb <= g -> best _ -> Just (g, thr) weightedGini :: [Int] -> [Int] -> Int -> Int -> Double weightedGini total leftAcc nl n = - (fromIntegral nl * giniImpurity leftAcc nl + fromIntegral nr * giniImpurity rightAcc nr) + ( fromIntegral nl * giniImpurity leftAcc nl + + fromIntegral nr * giniImpurity rightAcc nr + ) / fromIntegral n where nr = n - nl @@ -205,21 +252,27 @@ featuresOfColumn df c = case unsafeGetColumn c df of UnboxedColumn _ (v :: VU.Vector b) -> numericFeature @b c v BoxedColumn _ (v :: V.Vector b) -> oneHotFeatures @b (nRows df) c v -numericFeature :: forall b. (Columnable b, VU.Unbox b) => T.Text -> VU.Vector b -> [CartFeature] +numericFeature :: + forall b. (Columnable b, VU.Unbox b) => T.Text -> VU.Vector b -> [CartFeature] numericFeature c v = case testEquality (typeRep @b) (typeRep @Double) of Just Refl -> [CartFeature v (\t -> F.col @Double c .<=. F.lit t)] Nothing -> case sIntegral @b of - STrue -> [CartFeature (VU.map fromIntegral v) (\t -> F.toDouble (F.col @b c) .<=. F.lit t)] + STrue -> + [ CartFeature (VU.map fromIntegral v) (\t -> F.toDouble (F.col @b c) .<=. F.lit t) + ] SFalse -> [] -oneHotFeatures :: forall b. (Columnable b) => Int -> T.Text -> V.Vector b -> [CartFeature] +oneHotFeatures :: + forall b. (Columnable b) => Int -> T.Text -> V.Vector b -> [CartFeature] oneHotFeatures nAll c v = case testEquality (typeRep @b) (typeRep @T.Text) of Just Refl -> [oneHot nAll c v cat | cat <- Set.toList (Set.fromList (V.toList v))] Nothing -> [] oneHot :: Int -> T.Text -> V.Vector T.Text -> T.Text -> CartFeature oneHot nAll c v cat = - CartFeature (VU.generate nAll (\i -> if v V.! i == cat then 1 else 0)) (const (F.col @T.Text c ./=. F.lit cat)) + CartFeature + (VU.generate nAll (\i -> if v V.! i == cat then 1 else 0)) + (const (F.col @T.Text c ./=. F.lit cat)) -- | Target column as string labels (matches pandas @y.astype(str)@). cartTargetLabels :: T.Text -> DataFrame -> V.Vector T.Text diff --git a/dataframe-learn/src/DataFrame/DecisionTree/Categorical.hs b/dataframe-learn/src/DataFrame/DecisionTree/Categorical.hs index d00ec4b..b94df04 100644 --- a/dataframe-learn/src/DataFrame/DecisionTree/Categorical.hs +++ b/dataframe-learn/src/DataFrame/DecisionTree/Categorical.hs @@ -5,10 +5,11 @@ {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} --- | Categorical split candidates: Breiman prefixes for binary targets, --- subset/singleton enumeration otherwise, and cross-column equality. Each --- value-list yields an OR-of-equalities condition (as an expression or a --- directly-read membership truth vector). +{- | Categorical split candidates: Breiman prefixes for binary targets, +subset/singleton enumeration otherwise, and cross-column equality. Each +value-list yields an OR-of-equalities condition (as an expression or a +directly-read membership truth vector). +-} module DataFrame.DecisionTree.Categorical ( TargetInfo (..), mkTargetInfo, @@ -29,7 +30,12 @@ module DataFrame.DecisionTree.Categorical ( ) where import DataFrame.DecisionTree.CondVec (CondVec (..), materializeCondVec) -import DataFrame.DecisionTree.Types (ColumnOrdering, SynthConfig (..), TreeConfig (..), withOrdFrom) +import DataFrame.DecisionTree.Types ( + ColumnOrdering, + SynthConfig (..), + TreeConfig (..), + withOrdFrom, + ) import DataFrame.Internal.Column import DataFrame.Internal.DataFrame (DataFrame, columnNames, unsafeGetColumn) import DataFrame.Internal.Expression (Expr (..)) @@ -53,19 +59,25 @@ import Type.Reflection (typeRep) validBoxedValues :: Bitmap -> V.Vector a -> V.Vector a validBoxedValues bm = V.ifilter (\i _ -> bitmapTestBit bm i) --- | Target-column summary driving the categorical generator: binary vs --- multi-class, the deterministic positive class, and the raw label vector. +{- | Target-column summary driving the categorical generator: binary vs +multi-class, the deterministic positive class, and the raw label vector. +-} data TargetInfo target = TargetInfo { tiIsBinary :: !Bool , tiPositiveClass :: !(Maybe target) , tiValues :: !(V.Vector target) } --- | Compute 'TargetInfo' once per fit. The positive class for binary targets --- is the lexicographically-first distinct value, for deterministic pools. -mkTargetInfo :: forall target. (Columnable target, Ord target) => T.Text -> DataFrame -> Maybe (TargetInfo target) +{- | Compute 'TargetInfo' once per fit. The positive class for binary targets +is the lexicographically-first distinct value, for deterministic pools. +-} +mkTargetInfo :: + forall target. + (Columnable target, Ord target) => + T.Text -> DataFrame -> Maybe (TargetInfo target) mkTargetInfo target df = case interpret @target df (Col target) of - Right (TColumn column) -> either (const Nothing) (Just . targetInfoFromValues) (toVector @target column) + Right (TColumn column) -> + either (const Nothing) (Just . targetInfoFromValues) (toVector @target column) _ -> Nothing targetInfoFromValues :: (Ord target) => V.Vector target -> TargetInfo target @@ -77,8 +89,9 @@ targetInfoFromValues vals = TargetInfo isBinary posClass vals (p : _) | isBinary -> Just p _ -> Nothing --- | Distinct values, capped: @Right vs@ (sorted) under the cap, else @Left@ --- the count-so-far so the caller routes to the high-cardinality path. +{- | Distinct values, capped: @Right vs@ (sorted) under the cap, else @Left@ +the count-so-far so the caller routes to the high-cardinality path. +-} distinctValuesUpTo :: (Ord a) => Int -> V.Vector a -> Either Int [a] distinctValuesUpTo cap values = go Set.empty 0 where @@ -88,8 +101,9 @@ distinctValuesUpTo cap values = go Set.empty 0 | Set.size s > cap = Left (Set.size s) | otherwise = go (Set.insert (V.unsafeIndex values i) s) (i + 1) --- | OR-of-equalities for a value-list, shared by the expression and --- truth-vector discrete paths so they stay byte-identical. +{- | OR-of-equalities for a value-list, shared by the expression and +truth-vector discrete paths so they stay byte-identical. +-} orEqs :: (a -> Expr Bool) -> [a] -> Expr Bool orEqs eqLit = foldr1 (.||.) . map eqLit @@ -106,17 +120,28 @@ singletonSplits = map singletonLists :: [a] -> [[a]] singletonLists = map (: []) -breimanPrefixSplits :: (Ord a, Ord target) => target -> V.Vector a -> V.Vector target -> [a] -> (a -> Expr Bool) -> [Expr Bool] +breimanPrefixSplits :: + (Ord a, Ord target) => + target -> + V.Vector a -> + V.Vector target -> + [a] -> + (a -> Expr Bool) -> + [Expr Bool] breimanPrefixSplits pc values targetVals distinctVals eqLit = map (orEqs eqLit) (breimanPrefixLists pc values targetVals distinctVals) --- | Breiman's binary-target split set: sort levels by Laplace-smoothed --- positive rate, then take every contiguous non-trivial prefix. -breimanPrefixLists :: (Ord a, Ord target) => target -> V.Vector a -> V.Vector target -> [a] -> [[a]] +{- | Breiman's binary-target split set: sort levels by Laplace-smoothed +positive rate, then take every contiguous non-trivial prefix. +-} +breimanPrefixLists :: + (Ord a, Ord target) => target -> V.Vector a -> V.Vector target -> [a] -> [[a]] breimanPrefixLists pc values targetVals distinctVals = nonTrivialPrefixes (sortByRate (levelCounts pc values targetVals) distinctVals) -levelCounts :: (Ord a, Eq target) => target -> V.Vector a -> V.Vector target -> M.Map a (Int, Int) +levelCounts :: + (Ord a, Eq target) => + target -> V.Vector a -> V.Vector target -> M.Map a (Int, Int) levelCounts pc values targetVals = V.ifoldl' add M.empty values where add acc i v = M.insertWith plus v (indicator (V.unsafeIndex targetVals i == pc), 1) acc @@ -134,15 +159,19 @@ sortByRate counts = sortBy (compare `on` (\v -> (laplaceRate counts v, v))) nonTrivialPrefixes :: [a] -> [[a]] nonTrivialPrefixes = tail . init . inits --- | Value-lists a categorical column contributes; shared by the expression and --- truth-vector paths so both enumerate identical candidates in the same order. -catValueLists :: (Ord a, Ord target) => Bool -> Maybe target -> V.Vector target -> Int -> V.Vector a -> [[a]] +{- | Value-lists a categorical column contributes; shared by the expression and +truth-vector paths so both enumerate identical candidates in the same order. +-} +catValueLists :: + (Ord a, Ord target) => + Bool -> Maybe target -> V.Vector target -> Int -> V.Vector a -> [[a]] catValueLists isBinary posClass targetVals subsetCap values | V.null values = [] | isBinary, Just pc <- posClass = binaryLists pc targetVals values | otherwise = multiclassLists subsetCap values -binaryLists :: (Ord a, Ord target) => target -> V.Vector target -> V.Vector a -> [[a]] +binaryLists :: + (Ord a, Ord target) => target -> V.Vector target -> V.Vector a -> [[a]] binaryLists pc targetVals values | length distinct < 2 = [] | otherwise = breimanPrefixLists pc values targetVals distinct @@ -158,15 +187,17 @@ multiclassLists subsetCap values = case distinctValuesUpTo subsetCap values of ascDistinct :: (Ord a) => V.Vector a -> [a] ascDistinct = Set.toAscList . Set.fromList . V.toList --- | Truth vector of @col ∈ values@ read directly from the column; equal to --- interpreting @orEqs (== v) values@ because the values are distinct. +{- | Truth vector of @col ∈ values@ read directly from the column; equal to +interpreting @orEqs (== v) values@ because the values are distinct. +-} membershipVec :: (Ord a) => V.Vector a -> [a] -> VU.Vector Bool membershipVec colVals vs = let !s = Set.fromList vs in VU.generate (V.length colVals) (\i -> Set.member (colVals `V.unsafeIndex` i) s) --- | Per-fit categorical generation context bundling the target summary and --- the column-ordering registry. +{- | Per-fit categorical generation context bundling the target summary and +the column-ordering registry. +-} data CatCtx target = CatCtx { ccBinary :: !Bool , ccPos :: !(Maybe target) @@ -177,7 +208,12 @@ data CatCtx target = CatCtx catCtx :: TargetInfo target -> TreeConfig -> CatCtx target catCtx ti cfg = - CatCtx (tiIsBinary ti) (tiPositiveClass ti) (tiValues ti) (maxCategoricalSubsetCardinality (synthConfig cfg)) (columnOrdering cfg) + CatCtx + (tiIsBinary ti) + (tiPositiveClass ti) + (tiValues ti) + (maxCategoricalSubsetCardinality (synthConfig cfg)) + (columnOrdering cfg) catValueListsFor :: (Ord a, Ord target) => CatCtx target -> V.Vector a -> [[a]] catValueListsFor ctx = catValueLists (ccBinary ctx) (ccPos ctx) (ccTargets ctx) (ccSubsetCap ctx) @@ -190,26 +226,50 @@ isNumericKind = case sFloating @a of STrue -> True SFalse -> False --- | All equality-based candidate splits from non-numeric columns: per-column --- categorical conditions plus cross-column equality/order conditions. -discreteConditions :: forall target. (Columnable target, Ord target) => TargetInfo target -> TreeConfig -> DataFrame -> [Expr Bool] +{- | All equality-based candidate splits from non-numeric columns: per-column +categorical conditions plus cross-column equality/order conditions. +-} +discreteConditions :: + forall target. + (Columnable target, Ord target) => + TargetInfo target -> TreeConfig -> DataFrame -> [Expr Bool] discreteConditions targetInfo cfg df = - concatMap (columnConds (catCtx targetInfo cfg) df) (columnNames df) ++ crossColumnConds cfg df + concatMap (columnConds (catCtx targetInfo cfg) df) (columnNames df) + ++ crossColumnConds cfg df -columnConds :: (Columnable target, Ord target) => CatCtx target -> DataFrame -> T.Text -> [Expr Bool] +columnConds :: + (Columnable target, Ord target) => + CatCtx target -> DataFrame -> T.Text -> [Expr Bool] columnConds ctx df colName = case unsafeGetColumn colName df of BoxedColumn Nothing (column :: V.Vector a) -> nonNullColConds ctx colName column BoxedColumn (Just bm) (column :: V.Vector a) -> nullableColConds ctx colName bm column UnboxedColumn _ (_ :: VU.Vector a) -> [] -nonNullColConds :: forall a target. (Columnable a, Ord target) => CatCtx target -> T.Text -> V.Vector a -> [Expr Bool] +nonNullColConds :: + forall a target. + (Columnable a, Ord target) => + CatCtx target -> T.Text -> V.Vector a -> [Expr Bool] nonNullColConds ctx colName column = - fromMaybe [] (withOrdFrom @a (ccOrds ctx) (map (orEqs (eqExprFor @a colName)) (catValueListsFor ctx column))) - -nullableColConds :: forall a target. (Columnable a, Ord target) => CatCtx target -> T.Text -> Bitmap -> V.Vector a -> [Expr Bool] + fromMaybe + [] + ( withOrdFrom @a + (ccOrds ctx) + (map (orEqs (eqExprFor @a colName)) (catValueListsFor ctx column)) + ) + +nullableColConds :: + forall a target. + (Columnable a, Ord target) => + CatCtx target -> T.Text -> Bitmap -> V.Vector a -> [Expr Bool] nullableColConds ctx colName bm column | isNumericKind @a || V.null valid = [] - | otherwise = fromMaybe [] (withOrdFrom @a (ccOrds ctx) (map (orEqs (eqJustFor @a colName)) (catValueListsFor ctx valid))) + | otherwise = + fromMaybe + [] + ( withOrdFrom @a + (ccOrds ctx) + (map (orEqs (eqJustFor @a colName)) (catValueListsFor ctx valid)) + ) where valid = validBoxedValues bm column @@ -225,11 +285,18 @@ crossColumnConds cfg df = concatMap (pairConds (columnOrdering cfg) df) (allowed allowedPairs :: TreeConfig -> DataFrame -> [(T.Text, T.Text)] allowedPairs cfg df = - [(l, r) | l <- columnNames df, r <- columnNames df, l /= r, not (isDisallowedPair cfg l r)] + [ (l, r) + | l <- columnNames df + , r <- columnNames df + , l /= r + , not (isDisallowedPair cfg l r) + ] isDisallowedPair :: TreeConfig -> T.Text -> T.Text -> Bool isDisallowedPair cfg l r = - any (\(l', r') -> sort [l', r'] == sort [l, r]) (disallowedCombinations (synthConfig cfg)) + any + (\(l', r') -> sort [l', r'] == sort [l, r]) + (disallowedCombinations (synthConfig cfg)) pairConds :: ColumnOrdering -> DataFrame -> (T.Text, T.Text) -> [Expr Bool] pairConds ords df (l, r) = case (unsafeGetColumn l df, unsafeGetColumn r df) of @@ -237,20 +304,29 @@ pairConds ords df (l, r) = case (unsafeGetColumn l df, unsafeGetColumn r df) of (BoxedColumn (Just _) (_ :: V.Vector a), BoxedColumn (Just _) (_ :: V.Vector b)) -> nullablePairConds @a @b ords l r _ -> [] -strictPairConds :: forall a b. (Columnable a, Columnable b) => T.Text -> T.Text -> [Expr Bool] +strictPairConds :: + forall a b. (Columnable a, Columnable b) => T.Text -> T.Text -> [Expr Bool] strictPairConds l r = case testEquality (typeRep @a) (typeRep @b) of Just Refl -> [Col @a l .==. Col @a r] Nothing -> [] -nullablePairConds :: forall a b. (Columnable a, Columnable b) => ColumnOrdering -> T.Text -> T.Text -> [Expr Bool] +nullablePairConds :: + forall a b. + (Columnable a, Columnable b) => + ColumnOrdering -> T.Text -> T.Text -> [Expr Bool] nullablePairConds ords l r = case testEquality (typeRep @a) (typeRep @b) of Nothing -> [] Just Refl -> nullableEqOrLe @a ords l r -nullableEqOrLe :: forall a. (Columnable a) => ColumnOrdering -> T.Text -> T.Text -> [Expr Bool] +nullableEqOrLe :: + forall a. (Columnable a) => ColumnOrdering -> T.Text -> T.Text -> [Expr Bool] nullableEqOrLe ords l r | isTextType @a = eqOnly - | otherwise = maybe eqOnly (++ eqOnly) (withOrdFrom @a ords [Col @(Maybe a) l .<=. Col @(Maybe a) r]) + | otherwise = + maybe + eqOnly + (++ eqOnly) + (withOrdFrom @a ords [Col @(Maybe a) l .<=. Col @(Maybe a) r]) where eqOnly = [Col @(Maybe a) l .==. Col @(Maybe a) r] @@ -259,23 +335,37 @@ isTextType = case testEquality (typeRep @a) (typeRep @T.Text) of Just Refl -> True Nothing -> False --- | 'discreteConditions' materialized with shared per-column reads: the --- non-nullable categorical path builds truth vectors directly from one read --- per column; nullable and cross-column fall back to interpret. -discreteCondVecs :: forall target. (Columnable target, Ord target) => TargetInfo target -> TreeConfig -> DataFrame -> [CondVec] +{- | 'discreteConditions' materialized with shared per-column reads: the +non-nullable categorical path builds truth vectors directly from one read +per column; nullable and cross-column fall back to interpret. +-} +discreteCondVecs :: + forall target. + (Columnable target, Ord target) => + TargetInfo target -> TreeConfig -> DataFrame -> [CondVec] discreteCondVecs targetInfo cfg df = concatMap (columnCondVecs (catCtx targetInfo cfg) df) (columnNames df) ++ mapMaybe (materializeCondVec df) (crossColumnConds cfg df) -columnCondVecs :: (Columnable target, Ord target) => CatCtx target -> DataFrame -> T.Text -> [CondVec] +columnCondVecs :: + (Columnable target, Ord target) => + CatCtx target -> DataFrame -> T.Text -> [CondVec] columnCondVecs ctx df colName = case unsafeGetColumn colName df of BoxedColumn Nothing (column :: V.Vector a) -> nonNullColCondVecs ctx colName column BoxedColumn (Just bm) (column :: V.Vector a) -> mapMaybe (materializeCondVec df) (nullableColConds ctx colName bm column) UnboxedColumn _ (_ :: VU.Vector a) -> [] -nonNullColCondVecs :: forall a target. (Columnable a, Ord target) => CatCtx target -> T.Text -> V.Vector a -> [CondVec] +nonNullColCondVecs :: + forall a target. + (Columnable a, Ord target) => CatCtx target -> T.Text -> V.Vector a -> [CondVec] nonNullColCondVecs ctx colName column = - fromMaybe [] (withOrdFrom @a (ccOrds ctx) (map (membershipCondVec colName column) (catValueListsFor ctx column))) - -membershipCondVec :: forall a. (Columnable a, Ord a) => T.Text -> V.Vector a -> [a] -> CondVec + fromMaybe + [] + ( withOrdFrom @a + (ccOrds ctx) + (map (membershipCondVec colName column) (catValueListsFor ctx column)) + ) + +membershipCondVec :: + forall a. (Columnable a, Ord a) => T.Text -> V.Vector a -> [a] -> CondVec membershipCondVec colName column vs = CondVec (orEqs (eqExprFor @a colName) vs) (membershipVec column vs) diff --git a/dataframe-learn/src/DataFrame/DecisionTree/CondVec.hs b/dataframe-learn/src/DataFrame/DecisionTree/CondVec.hs index e928304..3817e22 100644 --- a/dataframe-learn/src/DataFrame/DecisionTree/CondVec.hs +++ b/dataframe-learn/src/DataFrame/DecisionTree/CondVec.hs @@ -3,9 +3,10 @@ {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} --- | Cached condition truth vectors and the per-fit cache keyed by structural --- form. A condition's truth over a fixed DataFrame is invariant for a whole --- fit, so it is materialized once and reused. +{- | Cached condition truth vectors and the per-fit cache keyed by structural +form. A condition's truth over a fixed DataFrame is invariant for a whole +fit, so it is materialized once and reused. +-} module DataFrame.DecisionTree.CondVec ( CondVec (..), materializeCondVec, @@ -25,7 +26,12 @@ import DataFrame.DecisionTree.Types (CarePoint (..), Direction (..), Tree (..)) import qualified DataFrame.Functions as F import DataFrame.Internal.Column (TypedColumn (..), toVector) import DataFrame.Internal.DataFrame (DataFrame) -import DataFrame.Internal.Expression (BinaryOp (binaryName), Expr (..), eqExpr, normalize) +import DataFrame.Internal.Expression ( + BinaryOp (binaryName), + Expr (..), + eqExpr, + normalize, + ) import DataFrame.Internal.Interpreter (interpret) import qualified Data.Map.Strict as M @@ -42,8 +48,9 @@ data CondVec = CondVec , cvVec :: !(VU.Vector Bool) } --- | Interpret a condition once over the DataFrame; 'Nothing' on a --- type/interpret failure so the candidate is silently dropped. +{- | Interpret a condition once over the DataFrame; 'Nothing' on a +type/interpret failure so the candidate is silently dropped. +-} materializeCondVec :: DataFrame -> Expr Bool -> Maybe CondVec materializeCondVec df cond = case interpret @Bool df cond of Left _ -> Nothing @@ -52,13 +59,15 @@ materializeCondVec df cond = case interpret @Bool df cond of eitherToMaybe :: Either e a -> Maybe a eitherToMaybe = either (const Nothing) Just --- | Full-DataFrame truth vectors keyed by structural form, read-only once --- built. Seeded for free from the candidate pool plus the initial tree so the --- predict/loss passes index a vector instead of re-interpreting per node. +{- | Full-DataFrame truth vectors keyed by structural form, read-only once +built. Seeded for free from the candidate pool plus the initial tree so the +predict/loss passes index a vector instead of re-interpreting per node. +-} type CondCache = M.Map T.Text (VU.Vector Bool) --- | Structural key matching the candidate-dedup key, so a tree branch whose --- condition came from the pool hits the cache (equal keys ⟹ equal vector). +{- | Structural key matching the candidate-dedup key, so a tree branch whose +condition came from the pool hits the cache (equal keys ⟹ equal vector). +-} condCacheKey :: Expr Bool -> T.Text condCacheKey = T.pack . show . normalize @@ -66,8 +75,9 @@ condCacheKey = T.pack . show . normalize condCacheFromVecs :: [CondVec] -> CondCache condCacheFromVecs cvs = M.fromList [(condCacheKey (cvExpr cv), cvVec cv) | cv <- cvs] --- | Add a tree's branch-condition vectors to a cache (one interpret per --- distinct, not-yet-cached condition). +{- | Add a tree's branch-condition vectors to a cache (one interpret per +distinct, not-yet-cached condition). +-} addTreeCondsToCache :: DataFrame -> Tree a -> CondCache -> CondCache addTreeCondsToCache df = go where @@ -77,12 +87,14 @@ addTreeCondsToCache df = go insertCond :: DataFrame -> Expr Bool -> CondCache -> CondCache insertCond df cond c | M.member k c = c - | otherwise = maybe c (\cv -> M.insert k (cvVec cv) c) (materializeCondVec df cond) + | otherwise = + maybe c (\cv -> M.insert k (cvVec cv) c) (materializeCondVec df cond) where k = condCacheKey cond --- | A condition's truth vector: a cache hit, else interpret over the --- DataFrame. 'Nothing' mirrors the interpret-failure fallback (route left). +{- | A condition's truth vector: a cache hit, else interpret over the +DataFrame. 'Nothing' mirrors the interpret-failure fallback (route left). +-} lookupCondVec :: CondCache -> DataFrame -> Expr Bool -> Maybe (VU.Vector Bool) lookupCondVec cache df cond = case M.lookup (condCacheKey cond) cache of hit@(Just _) -> hit @@ -98,8 +110,9 @@ countErrorsByVec boolVals = length . filter misrouted where misrouted cp = (boolVals VU.! cpIndex cp) /= (cpCorrectDir cp == GoLeft) --- | A same-column same-direction Double threshold comparison, with a rebuild --- function to swap in a new threshold. +{- | A same-column same-direction Double threshold comparison, with a rebuild +function to swap in a new threshold. +-} data ThreshCmp = ThreshCmp { tcCol :: !T.Text , tcName :: !T.Text @@ -109,7 +122,9 @@ data ThreshCmp = ThreshCmp asDoubleThreshold :: Expr Bool -> Maybe ThreshCmp asDoubleThreshold (Binary op (Col c :: Expr cc) (Lit (t :: tt))) = - case (testEquality (typeRep @cc) (typeRep @Double), testEquality (typeRep @tt) (typeRep @Double)) of + case ( testEquality (typeRep @cc) (typeRep @Double) + , testEquality (typeRep @tt) (typeRep @Double) + ) of (Just Refl, Just Refl) -> Just (ThreshCmp c (binaryName op) t (Binary op (Col c) . Lit)) _ -> Nothing asDoubleThreshold _ = Nothing @@ -117,8 +132,9 @@ asDoubleThreshold _ = Nothing directionalNames :: [T.Text] directionalNames = ["lt", "leq", "gt", "geq"] --- | Tighter (AND) or looser (OR) of two same-direction thresholds: @<@/@<=@ --- are left-half-spaces (AND = min), @>@/@>=@ are right-half-spaces (AND = max). +{- | Tighter (AND) or looser (OR) of two same-direction thresholds: @<@/@<=@ +are left-half-spaces (AND = min), @>@/@>=@ are right-half-spaces (AND = max). +-} chooseThreshold :: Bool -> T.Text -> Double -> Double -> Double chooseThreshold isAnd name t1 t2 | leftDir = if isAnd then min t1 t2 else max t1 t2 @@ -126,8 +142,9 @@ chooseThreshold isAnd name t1 t2 where leftDir = name == "lt" || name == "leq" --- | Collapse two same-column same-direction strict-Double comparisons into one --- comparison (the @True@ argument selects AND, @False@ OR); 'Nothing' otherwise. +{- | Collapse two same-column same-direction strict-Double comparisons into one +comparison (the @True@ argument selects AND, @False@ OR); 'Nothing' otherwise. +-} consolidateThreshold :: Bool -> Expr Bool -> Expr Bool -> Maybe (Expr Bool) consolidateThreshold isAnd ea eb = do a <- asDoubleThreshold ea @@ -136,20 +153,28 @@ consolidateThreshold isAnd ea eb = do then Just (tcRebuild a (chooseThreshold isAnd (tcName a) (tcThr a) (tcThr b))) else Nothing --- | AND-combine two cached conditions: idempotence and threshold consolidation --- first, else the generic @F.and@; the vector is always the elementwise AND. +{- | AND-combine two cached conditions: idempotence and threshold consolidation +first, else the generic @F.and@; the vector is always the elementwise AND. +-} combineAndVec :: CondVec -> CondVec -> CondVec combineAndVec a b | eqExpr (cvExpr a) (cvExpr b) = a | otherwise = CondVec expr (VU.zipWith (&&) (cvVec a) (cvVec b)) where - expr = fromMaybe (F.and (cvExpr a) (cvExpr b)) (consolidateThreshold True (cvExpr a) (cvExpr b)) - --- | OR-combine two cached conditions (see 'combineAndVec'; AND/OR direction --- differs in 'consolidateThreshold'). + expr = + fromMaybe + (F.and (cvExpr a) (cvExpr b)) + (consolidateThreshold True (cvExpr a) (cvExpr b)) + +{- | OR-combine two cached conditions (see 'combineAndVec'; AND/OR direction +differs in 'consolidateThreshold'). +-} combineOrVec :: CondVec -> CondVec -> CondVec combineOrVec a b | eqExpr (cvExpr a) (cvExpr b) = a | otherwise = CondVec expr (VU.zipWith (||) (cvVec a) (cvVec b)) where - expr = fromMaybe (F.or (cvExpr a) (cvExpr b)) (consolidateThreshold False (cvExpr a) (cvExpr b)) + expr = + fromMaybe + (F.or (cvExpr a) (cvExpr b)) + (consolidateThreshold False (cvExpr a) (cvExpr b)) diff --git a/dataframe-learn/src/DataFrame/DecisionTree/Fit.hs b/dataframe-learn/src/DataFrame/DecisionTree/Fit.hs index c75968d..1d6b398 100644 --- a/dataframe-learn/src/DataFrame/DecisionTree/Fit.hs +++ b/dataframe-learn/src/DataFrame/DecisionTree/Fit.hs @@ -3,9 +3,10 @@ {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} --- | Top-level fitting: assemble the candidate pool, seed from CART, run TAO, --- and convert the result to an expression. Also the probability-tree variant --- ('fitProbTree') that annotates leaves with class distributions. +{- | Top-level fitting: assemble the candidate pool, seed from CART, run TAO, +and convert the result to an expression. Also the probability-tree variant +('fitProbTree') that annotates leaves with class distributions. +-} module DataFrame.DecisionTree.Fit ( treeToExpr, fitDecisionTree, @@ -24,7 +25,12 @@ module DataFrame.DecisionTree.Fit ( ) where import DataFrame.DecisionTree.Cart (buildCartTree) -import DataFrame.DecisionTree.Categorical (TargetInfo (..), discreteConditions, discreteCondVecs, mkTargetInfo) +import DataFrame.DecisionTree.Categorical ( + TargetInfo (..), + discreteCondVecs, + discreteConditions, + mkTargetInfo, + ) import DataFrame.DecisionTree.CondVec (CondVec) import DataFrame.DecisionTree.Numeric (numericCondVecs, numericConditions) import DataFrame.DecisionTree.Pool (dedupCVByExpr, nubByExpr) @@ -54,9 +60,11 @@ treeToExpr (Leaf v) = Lit v treeToExpr (Branch cond left right) = F.ifThenElse cond (treeToExpr left) (treeToExpr right) -- | Fit a TAO decision tree (CART-seeded) and return it as an expression. -fitDecisionTree :: forall a. (Columnable a, Ord a) => TreeConfig -> Expr a -> DataFrame -> Expr a +fitDecisionTree :: + forall a. (Columnable a, Ord a) => TreeConfig -> Expr a -> DataFrame -> Expr a fitDecisionTree cfg (Col target) df = - pruneExpr (treeToExpr (taoOptimizeCV @a cfg target condVecs df indices initialTree)) + pruneExpr + (treeToExpr (taoOptimizeCV @a cfg target condVecs df indices initialTree)) where condVecs = candidatePool @a cfg target df initialTree = buildCartTree @a cfg target df @@ -64,18 +72,24 @@ fitDecisionTree cfg (Col target) df = fitDecisionTree _ expr _ = error ("Cannot create tree for compound expression: " ++ show expr) -- | The deduplicated numeric + discrete candidate pool for a target column. -candidatePool :: forall a. (Columnable a, Ord a) => TreeConfig -> T.Text -> DataFrame -> [CondVec] +candidatePool :: + forall a. + (Columnable a, Ord a) => TreeConfig -> T.Text -> DataFrame -> [CondVec] candidatePool cfg target df = dedupCVByExpr (numericCVs ++ discreteCVs) where dfNoTarget = exclude [target] df numericCVs = numericCondVecs cfg dfNoTarget df discreteCVs = discreteCondVecs (targetInfoOrEmpty @a target df) cfg dfNoTarget -targetInfoOrEmpty :: forall a. (Columnable a, Ord a) => T.Text -> DataFrame -> TargetInfo a +targetInfoOrEmpty :: + forall a. (Columnable a, Ord a) => T.Text -> DataFrame -> TargetInfo a targetInfoOrEmpty target df = fromMaybe (TargetInfo False Nothing V.empty) (mkTargetInfo @a target df) -- | Fit a tree at a given depth from a raw condition list (CART + TAO + prune). -buildTree :: forall a. (Columnable a, Ord a) => TreeConfig -> Int -> T.Text -> [Expr Bool] -> DataFrame -> Expr a +buildTree :: + forall a. + (Columnable a, Ord a) => + TreeConfig -> Int -> T.Text -> [Expr Bool] -> DataFrame -> Expr a buildTree cfg depth target conds df = pruneExpr (treeToExpr (taoOptimize @a cfg target conds df indices tree)) where @@ -89,7 +103,8 @@ partitionDataFrame :: Expr Bool -> DataFrame -> (DataFrame, DataFrame) partitionDataFrame cond df = (filterWhere cond df, filterWhere (F.not cond) df) -- | Laplace-smoothed Gini impurity of the target distribution. -calculateGini :: forall a. (Columnable a, Ord a) => T.Text -> DataFrame -> Double +calculateGini :: + forall a. (Columnable a, Ord a) => T.Text -> DataFrame -> Double calculateGini target df | n == 0 = 0 | otherwise = 1 - sum (map (^ (2 :: Int)) probs) @@ -106,7 +121,8 @@ majorityValue target df where counts = getCounts @a target df -getCounts :: forall a. (Columnable a, Ord a) => T.Text -> DataFrame -> M.Map a Int +getCounts :: + forall a. (Columnable a, Ord a) => T.Text -> DataFrame -> M.Map a Int getCounts target df = case interpret @a df (Col target) of Left e -> throw e Right (TColumn column) -> case toVector @a column of @@ -131,7 +147,9 @@ percentileOfVec p vals type ProbTree a = Tree (M.Map a Double) -- | Normalised class probabilities over a subset of training rows. -probsFromIndices :: forall a. (Columnable a, Ord a) => T.Text -> DataFrame -> V.Vector Int -> M.Map a Double +probsFromIndices :: + forall a. + (Columnable a, Ord a) => T.Text -> DataFrame -> V.Vector Int -> M.Map a Double probsFromIndices target df indices = case interpret @a df (Col target) of Right (TColumn column) -> either (const M.empty) (normaliseCounts indices) (toVector @a column) _ -> M.empty @@ -139,30 +157,51 @@ probsFromIndices target df indices = case interpret @a df (Col target) of normaliseCounts :: (Ord a) => V.Vector Int -> V.Vector a -> M.Map a Double normaliseCounts indices vals = M.map (\c -> fromIntegral c / total) counts where - counts = V.foldl' (\acc i -> M.insertWith (+) (vals V.! i) (1 :: Int) acc) M.empty indices + counts = + V.foldl' + (\acc i -> M.insertWith (+) (vals V.! i) (1 :: Int) acc) + M.empty + indices total = fromIntegral (V.length indices) :: Double --- | Re-label a fitted tree's leaves with class distributions, routing the --- training data through the (unchanged) split conditions. -buildProbTree :: forall a. (Columnable a, Ord a) => Tree a -> T.Text -> DataFrame -> V.Vector Int -> ProbTree a +{- | Re-label a fitted tree's leaves with class distributions, routing the +training data through the (unchanged) split conditions. +-} +buildProbTree :: + forall a. + (Columnable a, Ord a) => + Tree a -> T.Text -> DataFrame -> V.Vector Int -> ProbTree a buildProbTree (Leaf _) target df indices = Leaf (probsFromIndices @a target df indices) buildProbTree (Branch cond left right) target df indices = - Branch cond (buildProbTree @a left target df l) (buildProbTree @a right target df r) + Branch + cond + (buildProbTree @a left target df l) + (buildProbTree @a right target df r) where (l, r) = partitionIndices cond df indices -- | Fit a TAO tree and return one probability expression per class. -fitProbTree :: forall a. (Columnable a, Ord a) => TreeConfig -> Expr a -> DataFrame -> M.Map a (Expr Double) +fitProbTree :: + forall a. + (Columnable a, Ord a) => + TreeConfig -> Expr a -> DataFrame -> M.Map a (Expr Double) fitProbTree cfg (Col target) df = probExprs (buildProbTree @a pruned target df indices) where - conds = nubByExpr (numericConditions cfg dfNoTarget ++ discreteConditions (targetInfoOrEmpty @a target df) cfg dfNoTarget) + conds = + nubByExpr + ( numericConditions cfg dfNoTarget + ++ discreteConditions (targetInfoOrEmpty @a target df) cfg dfNoTarget + ) dfNoTarget = exclude [target] df indices = V.enumFromN 0 (nRows df) - pruned = pruneDead (taoOptimize @a cfg target conds df indices (buildCartTree @a cfg target df)) + pruned = + pruneDead + (taoOptimize @a cfg target conds df indices (buildCartTree @a cfg target df)) fitProbTree _ expr _ = error ("Cannot create prob tree for compound expression: " ++ show expr) -- | Convert a 'ProbTree' into one @Expr Double@ per class. -probExprs :: forall a. (Columnable a, Ord a) => ProbTree a -> M.Map a (Expr Double) +probExprs :: + forall a. (Columnable a, Ord a) => ProbTree a -> M.Map a (Expr Double) probExprs tree = M.fromList [(c, classExpr c tree) | c <- nub (allClasses tree)] allClasses :: ProbTree a -> [a] diff --git a/dataframe-learn/src/DataFrame/DecisionTree/Linear.hs b/dataframe-learn/src/DataFrame/DecisionTree/Linear.hs index ced98e5..ed389c8 100644 --- a/dataframe-learn/src/DataFrame/DecisionTree/Linear.hs +++ b/dataframe-learn/src/DataFrame/DecisionTree/Linear.hs @@ -1,9 +1,10 @@ {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE TypeApplications #-} --- | Oblique split candidates: fit an L1-regularised logistic hyperplane to the --- care points (class-balanced) and convert it to a boolean condition, rejecting --- all-zero and degenerate (single-side) hyperplanes. +{- | Oblique split candidates: fit an L1-regularised logistic hyperplane to the +care points (class-balanced) and convert it to a boolean condition, rejecting +all-zero and degenerate (single-side) hyperplanes. +-} module DataFrame.DecisionTree.Linear ( bestLinearCandidate, fitLinearCandidate, @@ -15,7 +16,11 @@ module DataFrame.DecisionTree.Linear ( ) where import DataFrame.DecisionTree.Numeric (NumExpr (..), numericCols) -import DataFrame.DecisionTree.Types (CarePoint (..), Direction (..), TreeConfig (..)) +import DataFrame.DecisionTree.Types ( + CarePoint (..), + Direction (..), + TreeConfig (..), + ) import DataFrame.Internal.Column (TypedColumn (..), toVector) import DataFrame.Internal.DataFrame (DataFrame) import DataFrame.Internal.Expression (Expr, getColumns) @@ -27,18 +32,22 @@ import qualified Data.Text as T import qualified Data.Vector as V import qualified Data.Vector.Unboxed as VU --- | Best oblique candidate, or 'Nothing' when the linear path is disabled or --- there are too few care points to fit on. -bestLinearCandidate :: TreeConfig -> DataFrame -> [CarePoint] -> Maybe (Expr Bool) +{- | Best oblique candidate, or 'Nothing' when the linear path is disabled or +there are too few care points to fit on. +-} +bestLinearCandidate :: + TreeConfig -> DataFrame -> [CarePoint] -> Maybe (Expr Bool) bestLinearCandidate cfg df carePoints | not (useLinearSolver cfg) = Nothing | length carePoints < minCarePointsForLinear cfg = Nothing | otherwise = fitLinearCandidate cfg df carePoints --- | Fit an L1 logistic regression to the care points and convert the resulting --- hyperplane to a condition, or 'Nothing' when no numeric features exist or the --- fitted model is all-zero or degenerate. -fitLinearCandidate :: TreeConfig -> DataFrame -> [CarePoint] -> Maybe (Expr Bool) +{- | Fit an L1 logistic regression to the care points and convert the resulting +hyperplane to a condition, or 'Nothing' when no numeric features exist or the +fitted model is all-zero or degenerate. +-} +fitLinearCandidate :: + TreeConfig -> DataFrame -> [CarePoint] -> Maybe (Expr Bool) fitLinearCandidate cfg df carePoints = case materializedFeatures df carePoints of [] -> Nothing mats -> linearFromFeatures cfg carePoints mats @@ -46,7 +55,8 @@ fitLinearCandidate cfg df carePoints = case materializedFeatures df carePoints o materializedFeatures :: DataFrame -> [CarePoint] -> [(T.Text, VU.Vector Double)] materializedFeatures df carePoints = mapMaybe (materializeFeatureForCare df carePoints) (numericCols df) -linearFromFeatures :: TreeConfig -> [CarePoint] -> [(T.Text, VU.Vector Double)] -> Maybe (Expr Bool) +linearFromFeatures :: + TreeConfig -> [CarePoint] -> [(T.Text, VU.Vector Double)] -> Maybe (Expr Bool) linearFromFeatures cfg carePoints mats | VU.all (== 0) weights = Nothing | degenerateHyperplane rows weights (LS.lmIntercept model) = Nothing @@ -54,14 +64,20 @@ linearFromFeatures cfg carePoints mats where rows = careRowsFromFeatures (length carePoints) mats labels = careLabels carePoints - model = LS.fitL1Logistic (solverConfigFor cfg labels) rows labels (V.fromList (map fst mats)) + model = + LS.fitL1Logistic + (solverConfigFor cfg labels) + rows + labels + (V.fromList (map fst mats)) weights = LS.lmWeights model solverConfigFor :: TreeConfig -> VU.Vector Double -> LS.SolverConfig solverConfigFor cfg labels = (linearSolverConfig cfg){LS.scSampleWeights = classBalancedWeights labels} --- | Class-balanced sklearn-form weights @w_i = N / (2 · N_class)@ (mean 1), or --- 'Nothing' in the degenerate one-class case (uniform weighting). +{- | Class-balanced sklearn-form weights @w_i = N / (2 · N_class)@ (mean 1), or +'Nothing' in the degenerate one-class case (uniform weighting). +-} classBalancedWeights :: VU.Vector Double -> Maybe (VU.Vector Double) classBalancedWeights labels | nPos > 0 && nNeg > 0 = Just (VU.generate nCare weightAt) @@ -74,18 +90,25 @@ classBalancedWeights labels | VU.unsafeIndex labels i > 0 = fromIntegral nCare / (2 * fromIntegral nPos) | otherwise = fromIntegral nCare / (2 * fromIntegral nNeg) --- | A hyperplane is degenerate when every care row scores on the same side of --- zero (equivalent to an invalid split, caught upstream). -degenerateHyperplane :: V.Vector (VU.Vector Double) -> VU.Vector Double -> Double -> Bool +{- | A hyperplane is degenerate when every care row scores on the same side of +zero (equivalent to an invalid split, caught upstream). +-} +degenerateHyperplane :: + V.Vector (VU.Vector Double) -> VU.Vector Double -> Double -> Bool degenerateHyperplane rows weights bias = nCare > 0 && (VU.minimum scores > 0 || VU.maximum scores < 0) where nCare = V.length rows - scores = VU.generate nCare (\i -> VU.sum (VU.zipWith (*) weights (V.unsafeIndex rows i)) + bias) - --- | Per-care-point feature rows from materialized columns (each of length --- @nCare@, so indexing is in range). -careRowsFromFeatures :: Int -> [(T.Text, VU.Vector Double)] -> V.Vector (VU.Vector Double) + scores = + VU.generate + nCare + (\i -> VU.sum (VU.zipWith (*) weights (V.unsafeIndex rows i)) + bias) + +{- | Per-care-point feature rows from materialized columns (each of length +@nCare@, so indexing is in range). +-} +careRowsFromFeatures :: + Int -> [(T.Text, VU.Vector Double)] -> V.Vector (VU.Vector Double) careRowsFromFeatures nCare mats = V.generate nCare (\i -> VU.generate nFeat (\j -> snd (matsVec V.! j) VU.! i)) where @@ -94,7 +117,8 @@ careRowsFromFeatures nCare mats = -- | Solver labels: @+1@ when 'GoLeft' is correct, @-1@ otherwise. careLabels :: [CarePoint] -> VU.Vector Double -careLabels carePoints = VU.fromList [if cpCorrectDir cp == GoLeft then 1.0 else -1.0 | cp <- carePoints] +careLabels carePoints = + VU.fromList [if cpCorrectDir cp == GoLeft then 1.0 else -1.0 | cp <- carePoints] -- | First column referenced by an expression, or a placeholder when none. featName :: Expr b -> T.Text @@ -102,8 +126,9 @@ featName expr = case getColumns expr of (c : _) -> c [] -> "" --- | Replace missing values with the mean of present ones; 'Nothing' when --- nothing is present so the caller can drop the feature. +{- | Replace missing values with the mean of present ones; 'Nothing' when +nothing is present so the caller can drop the feature. +-} imputeMean :: [Maybe Double] -> Maybe (VU.Vector Double) imputeMean careRaw = case catMaybes careRaw of [] -> Nothing @@ -116,14 +141,17 @@ interpretDoubleVals df expr = case interpret @Double df expr of Right (TColumn column) -> either (const Nothing) Just (toVector @Double column) _ -> Nothing -interpretMaybeDoubleVals :: DataFrame -> Expr (Maybe Double) -> Maybe (V.Vector (Maybe Double)) +interpretMaybeDoubleVals :: + DataFrame -> Expr (Maybe Double) -> Maybe (V.Vector (Maybe Double)) interpretMaybeDoubleVals df expr = case interpret @(Maybe Double) df expr of Right (TColumn column) -> either (const Nothing) Just (toVector @(Maybe Double) column) _ -> Nothing --- | Materialize a 'NumExpr' over the care rows; 'Nothing' on interpret failure --- or (nullable) when no care point has a present value, else mean-imputed. -materializeFeatureForCare :: DataFrame -> [CarePoint] -> NumExpr -> Maybe (T.Text, VU.Vector Double) +{- | Materialize a 'NumExpr' over the care rows; 'Nothing' on interpret failure +or (nullable) when no care point has a present value, else mean-imputed. +-} +materializeFeatureForCare :: + DataFrame -> [CarePoint] -> NumExpr -> Maybe (T.Text, VU.Vector Double) materializeFeatureForCare df carePoints (NDouble expr) = do vals <- interpretDoubleVals df expr Just (featName expr, VU.fromList [vals V.! cpIndex cp | cp <- carePoints]) diff --git a/dataframe-learn/src/DataFrame/DecisionTree/Numeric.hs b/dataframe-learn/src/DataFrame/DecisionTree/Numeric.hs index b2d115a..59f326f 100644 --- a/dataframe-learn/src/DataFrame/DecisionTree/Numeric.hs +++ b/dataframe-learn/src/DataFrame/DecisionTree/Numeric.hs @@ -5,10 +5,11 @@ {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} --- | Numeric split candidates: per-column Double expressions, arithmetic --- expansion, and threshold conditions. 'numericCondVecs' materializes the --- pool with a single interpret per distinct expression, deriving every --- threshold/operator truth vector by direct comparison. +{- | Numeric split candidates: per-column Double expressions, arithmetic +expansion, and threshold conditions. 'numericCondVecs' materializes the +pool with a single interpret per distinct expression, deriving every +threshold/operator truth vector by direct comparison. +-} module DataFrame.DecisionTree.Numeric ( NumExpr (..), numExprCols, @@ -67,11 +68,29 @@ combineNumExprs :: NumExpr -> NumExpr -> [NumExpr] combineNumExprs (NDouble e1) (NDouble e2) = map NDouble [e1 .+ e2, e1 .- e2, e1 .* e2, safeDivD e1 e2] combineNumExprs (NDouble e1) (NMaybeDouble e2) = - map NMaybeDouble [e1 .+ e2, e1 .- e2, e1 .* e2, safeDivMaybe (F.fromMaybe False (e2 ./= F.lit (0 :: Double))) (e1 ./ e2)] + map + NMaybeDouble + [ e1 .+ e2 + , e1 .- e2 + , e1 .* e2 + , safeDivMaybe (F.fromMaybe False (e2 ./= F.lit (0 :: Double))) (e1 ./ e2) + ] combineNumExprs (NMaybeDouble e1) (NDouble e2) = - map NMaybeDouble [e1 .+ e2, e1 .- e2, e1 .* e2, safeDivMaybe (e2 ./= F.lit (0 :: Double)) (e1 ./ e2)] + map + NMaybeDouble + [ e1 .+ e2 + , e1 .- e2 + , e1 .* e2 + , safeDivMaybe (e2 ./= F.lit (0 :: Double)) (e1 ./ e2) + ] combineNumExprs (NMaybeDouble e1) (NMaybeDouble e2) = - map NMaybeDouble [e1 .+ e2, e1 .- e2, e1 .* e2, safeDivMaybe (F.fromMaybe False (e2 ./= F.lit (0 :: Double))) (e1 ./ e2)] + map + NMaybeDouble + [ e1 .+ e2 + , e1 .- e2 + , e1 .* e2 + , safeDivMaybe (F.fromMaybe False (e2 ./= F.lit (0 :: Double))) (e1 ./ e2) + ] numericConditions :: TreeConfig -> DataFrame -> [Expr Bool] numericConditions = generateNumericConds @@ -92,10 +111,14 @@ thresholdsForExpr cfg df e = condsFromExpr :: NumExpr -> Double -> [Expr Bool] condsFromExpr (NDouble e) t = [e .<= F.lit t, e .>= F.lit t, e .< F.lit t, e .> F.lit t] -condsFromExpr (NMaybeDouble e) t = map (F.fromMaybe False) [e .<= F.lit t, e .>= F.lit t, e .< F.lit t, e .> F.lit t] - --- | Percentile thresholds for a value list: sort once, index each percentile. --- Shared by 'generateNumericConds' and 'numericCondVecs' for identical results. +condsFromExpr (NMaybeDouble e) t = + map + (F.fromMaybe False) + [e .<= F.lit t, e .>= F.lit t, e .< F.lit t, e .> F.lit t] + +{- | Percentile thresholds for a value list: sort once, index each percentile. +Shared by 'generateNumericConds' and 'numericCondVecs' for identical results. +-} percentilesOf :: [Int] -> [Double] -> [Double] percentilesOf ps valsList | n == 0 = [] @@ -109,15 +132,17 @@ interpretDoubleCol df e = case interpret @Double df e of Right (TColumn column) -> either (const Nothing) Just (toVector @Double column) _ -> Nothing -interpretMaybeDoubleCol :: DataFrame -> Expr (Maybe Double) -> Maybe (V.Vector (Maybe Double)) +interpretMaybeDoubleCol :: + DataFrame -> Expr (Maybe Double) -> Maybe (V.Vector (Maybe Double)) interpretMaybeDoubleCol df e = case interpret @(Maybe Double) df e of Right (TColumn column) -> either (const Nothing) Just (toVector @(Maybe Double) column) _ -> Nothing --- | Materialize the numeric pool with one interpret per distinct expression, --- deriving each threshold/operator truth vector by direct comparison. --- Byte-identical to materializing 'numericConditions' one at a time, but --- avoids re-interpreting each LHS per threshold and operator. +{- | Materialize the numeric pool with one interpret per distinct expression, +deriving each threshold/operator truth vector by direct comparison. +Byte-identical to materializing 'numericConditions' one at a time, but +avoids re-interpreting each LHS per threshold and operator. +-} numericCondVecs :: TreeConfig -> DataFrame -> DataFrame -> [CondVec] numericCondVecs cfg dfGen df = concatMap forExpr (numericExprsWithTerms (synthConfig cfg) dfGen) where @@ -139,12 +164,14 @@ doubleCondsAt e vals n t = where gen p = VU.generate n (\i -> p (vals V.! i)) -condsForMaybe :: TreeConfig -> Expr (Maybe Double) -> V.Vector (Maybe Double) -> [CondVec] +condsForMaybe :: + TreeConfig -> Expr (Maybe Double) -> V.Vector (Maybe Double) -> [CondVec] condsForMaybe cfg e mvals = concatMap (maybeCondsAt e mvals (V.length mvals)) ts where ts = percentilesOf (percentiles cfg) (map (fromMaybe 0) (V.toList mvals)) -maybeCondsAt :: Expr (Maybe Double) -> V.Vector (Maybe Double) -> Int -> Double -> [CondVec] +maybeCondsAt :: + Expr (Maybe Double) -> V.Vector (Maybe Double) -> Int -> Double -> [CondVec] maybeCondsAt e mvals n t = [ CondVec (F.fromMaybe False (e .<= F.lit t)) (gen (<= t)) , CondVec (F.fromMaybe False (e .>= F.lit t)) (gen (>= t)) @@ -154,13 +181,15 @@ maybeCondsAt e mvals n t = where gen p = VU.generate n (\i -> maybe False p (mvals V.! i)) --- | Arithmetic candidate expansion, generated already-deduped: each round --- combines @frontier × base@ and admits only normalized-novel candidates. --- Produces @base@ plus @maxExprDepth-1@ combination rounds. +{- | Arithmetic candidate expansion, generated already-deduped: each round +combines @frontier × base@ and admits only normalized-novel candidates. +Produces @base@ plus @maxExprDepth-1@ combination rounds. +-} numericExprsWithTerms :: SynthConfig -> DataFrame -> [NumExpr] numericExprsWithTerms cfg df | not (enableArithOps cfg) = base - | otherwise = base ++ expandRounds cfg base (max 0 (maxExprDepth cfg - 1)) base seen0 + | otherwise = + base ++ expandRounds cfg base (max 0 (maxExprDepth cfg - 1)) base seen0 where base = numericCols df seen0 = Set.fromList (map keyNum base) @@ -177,9 +206,16 @@ isDisallowed cfg e1 e2 = roundProducts :: SynthConfig -> [NumExpr] -> [NumExpr] -> [NumExpr] roundProducts cfg frontier base = - [c | e1 <- frontier, e2 <- base, not (numExprEq e1 e2), not (isDisallowed cfg e1 e2), c <- combineNumExprs e1 e2] + [ c + | e1 <- frontier + , e2 <- base + , not (numExprEq e1 e2) + , not (isDisallowed cfg e1 e2) + , c <- combineNumExprs e1 e2 + ] -expandRounds :: SynthConfig -> [NumExpr] -> Int -> [NumExpr] -> Set.Set String -> [NumExpr] +expandRounds :: + SynthConfig -> [NumExpr] -> Int -> [NumExpr] -> Set.Set String -> [NumExpr] expandRounds _ _ 0 _ _ = [] expandRounds cfg base d frontier seen | null admitted = [] diff --git a/dataframe-learn/src/DataFrame/DecisionTree/Pool.hs b/dataframe-learn/src/DataFrame/DecisionTree/Pool.hs index de59d84..828f3c7 100644 --- a/dataframe-learn/src/DataFrame/DecisionTree/Pool.hs +++ b/dataframe-learn/src/DataFrame/DecisionTree/Pool.hs @@ -1,9 +1,10 @@ {-# LANGUAGE BangPatterns #-} {-# LANGUAGE OverloadedStrings #-} --- | Candidate-pool scoring and boolean expansion: penalized scoring, diverse --- top-K selection, AND/OR saturation, and structural/truth-vector dedup. The --- per-node scoring scans run in parallel chunks. +{- | Candidate-pool scoring and boolean expansion: penalized scoring, diverse +top-K selection, AND/OR saturation, and structural/truth-vector dedup. The +per-node scoring scans run in parallel chunks. +-} module DataFrame.DecisionTree.Pool ( evalWithPenaltyVec, primaryColExpr, @@ -21,9 +22,25 @@ module DataFrame.DecisionTree.Pool ( nubByExpr, ) where -import DataFrame.DecisionTree.CondVec (CondVec (..), combineAndVec, combineOrVec, countErrorsByVec) -import DataFrame.DecisionTree.Types (CarePoint, SynthConfig (..), TreeConfig (..)) -import DataFrame.Internal.Expression (Expr, compareExpr, eSize, eqExpr, getColumns, normalize) +import DataFrame.DecisionTree.CondVec ( + CondVec (..), + combineAndVec, + combineOrVec, + countErrorsByVec, + ) +import DataFrame.DecisionTree.Types ( + CarePoint, + SynthConfig (..), + TreeConfig (..), + ) +import DataFrame.Internal.Expression ( + Expr, + compareExpr, + eSize, + eqExpr, + getColumns, + normalize, + ) import Control.Parallel.Strategies (parListChunk, rdeepseq, using) import Data.Function (on) @@ -33,16 +50,18 @@ import qualified Data.Set as Set import qualified Data.Text as T import qualified Data.Vector.Unboxed as VU --- | Penalized score of a candidate: care-point errors plus a complexity --- penalty, tie-broken by expression size. +{- | Penalized score of a candidate: care-point errors plus a complexity +penalty, tie-broken by expression size. +-} evalWithPenaltyVec :: TreeConfig -> [CarePoint] -> CondVec -> (Int, Int) evalWithPenaltyVec cfg carePoints cv = (countErrorsByVec (cvVec cv) carePoints + penalty, sz) where sz = eSize (cvExpr cv) penalty = floor (complexityPenalty (synthConfig cfg) * fromIntegral sz) --- | First referenced column of a condition (a sentinel for literal-only ones), --- used by 'takeDiverse' to enforce per-column diversity. +{- | First referenced column of a condition (a sentinel for literal-only ones), +used by 'takeDiverse' to enforce per-column diversity. +-} primaryColExpr :: Expr Bool -> T.Text primaryColExpr e = case getColumns e of [] -> "" @@ -51,8 +70,9 @@ primaryColExpr e = case getColumns e of primaryColCV :: CondVec -> T.Text primaryColCV = primaryColExpr . cvExpr --- | Keep the first @k@ of an already-sorted list, admitting at most @quota@ per --- primary column (@Nothing@ disables the per-column cap). +{- | Keep the first @k@ of an already-sorted list, admitting at most @quota@ per +primary column (@Nothing@ disables the per-column cap). +-} takeDiverse :: Int -> Maybe Int -> (a -> T.Text) -> [a] -> [a] takeDiverse k Nothing _ = take k takeDiverse k (Just quota) primary = go M.empty 0 @@ -65,36 +85,51 @@ takeDiverse k (Just quota) primary = go M.empty 0 where !col = primary x --- | Chunk size for the parallel per-node candidate scans; tuned by an -N --- sweep, not correctness-affecting. +{- | Chunk size for the parallel per-node candidate scans; tuned by an -N +sweep, not correctness-affecting. +-} candidateParChunk :: Int candidateParChunk = 64 --- | Decorate candidates with their penalty in parallel chunks, forcing only --- the @(Int, Int)@ key so the order (hence later sorts/minima) is preserved. +{- | Decorate candidates with their penalty in parallel chunks, forcing only +the @(Int, Int)@ key so the order (hence later sorts/minima) is preserved. +-} decorate :: (CondVec -> (Int, Int)) -> [CondVec] -> [((Int, Int), CondVec)] decorate penaltyCV xs = zip (map penaltyCV xs `using` parListChunk candidateParChunk rdeepseq) xs -- | The diverse top-@expressionPairs@ valid candidates by penalty. sortedTopK :: TreeConfig -> (CondVec -> (Int, Int)) -> [CondVec] -> [CondVec] sortedTopK cfg penaltyCV validCondVecs = - map snd (takeDiverse (expressionPairs cfg) (perColumnQuota (synthConfig cfg)) (primaryColCV . snd) sorted) + map + snd + ( takeDiverse + (expressionPairs cfg) + (perColumnQuota (synthConfig cfg)) + (primaryColCV . snd) + sorted + ) where sorted = sortBy (compare `on` fst) (decorate penaltyCV validCondVecs) -- | Lowest-penalty candidate after boolean saturation of the diverse top-K. -bestDiscreteCandidate :: TreeConfig -> (CondVec -> (Int, Int)) -> [CondVec] -> Maybe CondVec +bestDiscreteCandidate :: + TreeConfig -> (CondVec -> (Int, Int)) -> [CondVec] -> Maybe CondVec bestDiscreteCandidate _ _ [] = Nothing bestDiscreteCandidate cfg penaltyCV validCondVecs = - case saturateCandidates Structural (boolExpansion (synthConfig cfg)) (sortedTopK cfg penaltyCV validCondVecs) of + case saturateCandidates + Structural + (boolExpansion (synthConfig cfg)) + (sortedTopK cfg penaltyCV validCondVecs) of [] -> Nothing xs -> Just (snd (minimumBy (compare `on` fst) (decorate penaltyCV xs))) --- | AND/OR expansion of cached conditions to depth @maxDepth@ (each --- combination is a single vector op, not an interpret). +{- | AND/OR expansion of cached conditions to depth @maxDepth@ (each +combination is a single vector op, not an interpret). +-} boolExprsVec :: [CondVec] -> [CondVec] -> Int -> Int -> [CondVec] boolExprsVec baseExprs prevExprs depth maxDepth - | depth == 0 = baseExprs ++ boolExprsVec baseExprs prevExprs (depth + 1) maxDepth + | depth == 0 = + baseExprs ++ boolExprsVec baseExprs prevExprs (depth + 1) maxDepth | depth >= maxDepth = [] | otherwise = combined ++ boolExprsVec baseExprs combined (depth + 1) maxDepth where @@ -103,27 +138,38 @@ boolExprsVec baseExprs prevExprs depth maxDepth data DedupMode = Structural | TruthVector deriving (Eq, Show) --- | Saturate the pool with AND/OR combinations, deduplicating structurally --- (byte-identical, first occurrence kept) or by truth vector (opt-in). +{- | Saturate the pool with AND/OR combinations, deduplicating structurally +(byte-identical, first occurrence kept) or by truth vector (opt-in). +-} saturateCandidates :: DedupMode -> Int -> [CondVec] -> [CondVec] saturateCandidates Structural maxDepth base = base' ++ go 1 base' seen0 where (base', seen0) = admitKeys Set.empty base go !depth frontier seen | depth >= maxDepth || null frontier = [] - | otherwise = let (admitted, seen') = admitKeys seen (roundProducts frontier base) in admitted ++ go (depth + 1) admitted seen' + | otherwise = + let (admitted, seen') = admitKeys seen (roundProducts frontier base) + in admitted ++ go (depth + 1) admitted seen' saturateCandidates TruthVector maxDepth base = M.elems (go 1 frontier0 reps0) where (reps0, frontier0) = admitVecs M.empty base go !depth frontier reps | depth >= maxDepth || null frontier = reps - | otherwise = let (reps', admitted) = admitVecs reps (roundProducts frontier base) in go (depth + 1) admitted reps' + | otherwise = + let (reps', admitted) = admitVecs reps (roundProducts frontier base) + in go (depth + 1) admitted reps' --- | One combination round: @frontier × base@ via AND then OR, skipping --- self-pairs (mirrors 'boolExprsVec' for byte-identical structural output). +{- | One combination round: @frontier × base@ via AND then OR, skipping +self-pairs (mirrors 'boolExprsVec' for byte-identical structural output). +-} roundProducts :: [CondVec] -> [CondVec] -> [CondVec] roundProducts frontier base = - [c | e1 <- frontier, e2 <- base, not (eqExpr (cvExpr e1) (cvExpr e2)), c <- [combineAndVec e1 e2, combineOrVec e1 e2]] + [ c + | e1 <- frontier + , e2 <- base + , not (eqExpr (cvExpr e1) (cvExpr e2)) + , c <- [combineAndVec e1 e2, combineOrVec e1 e2] + ] -- | Admit candidates with a not-yet-seen normalized form, preserving order. admitKeys :: Set.Set String -> [CondVec] -> ([CondVec], Set.Set String) @@ -137,9 +183,13 @@ admitKeys = go [] structuralKey :: CondVec -> String structuralKey = show . normalize . cvExpr --- | Admit candidates by distinct truth vector, keeping the smallest-expression --- representative per vector. -admitVecs :: M.Map (VU.Vector Bool) CondVec -> [CondVec] -> (M.Map (VU.Vector Bool) CondVec, [CondVec]) +{- | Admit candidates by distinct truth vector, keeping the smallest-expression +representative per vector. +-} +admitVecs :: + M.Map (VU.Vector Bool) CondVec -> + [CondVec] -> + (M.Map (VU.Vector Bool) CondVec, [CondVec]) admitVecs = go [] where go acc reps [] = (reps, reverse acc) diff --git a/dataframe-learn/src/DataFrame/DecisionTree/Predict.hs b/dataframe-learn/src/DataFrame/DecisionTree/Predict.hs index d9c20cb..e1bcb23 100644 --- a/dataframe-learn/src/DataFrame/DecisionTree/Predict.hs +++ b/dataframe-learn/src/DataFrame/DecisionTree/Predict.hs @@ -2,9 +2,10 @@ {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} --- | Prediction, care-point identification, node validity, and tree loss. The --- batched, cache-aware variants resolve each branch condition's truth vector --- once per call instead of once per row. +{- | Prediction, care-point identification, node validity, and tree loss. The +batched, cache-aware variants resolve each branch condition's truth vector +once per call instead of once per row. +-} module DataFrame.DecisionTree.Predict ( predictWithTree, predictManyWithTree, @@ -20,8 +21,17 @@ module DataFrame.DecisionTree.Predict ( isValidAtNode, ) where -import DataFrame.DecisionTree.CondVec (CondCache, countErrorsByVec, lookupCondVec) -import DataFrame.DecisionTree.Types (CarePoint (..), Direction (..), Tree (..), TreeConfig (..)) +import DataFrame.DecisionTree.CondVec ( + CondCache, + countErrorsByVec, + lookupCondVec, + ) +import DataFrame.DecisionTree.Types ( + CarePoint (..), + Direction (..), + Tree (..), + TreeConfig (..), + ) import DataFrame.Internal.Column (Columnable, TypedColumn (..), toVector) import DataFrame.Internal.DataFrame (DataFrame) import DataFrame.Internal.Expression (Expr (..)) @@ -37,21 +47,24 @@ import qualified Data.Vector as V import qualified Data.Vector.Mutable as VM import qualified Data.Vector.Unboxed as VU --- | A condition's truth vector over the DataFrame, or 'Nothing' on a --- type/interpret failure (callers default such rows to the left child). +{- | A condition's truth vector over the DataFrame, or 'Nothing' on a +type/interpret failure (callers default such rows to the left child). +-} branchBool :: DataFrame -> Expr Bool -> Maybe (VU.Vector Bool) branchBool df cond = case interpret @Bool df cond of Right (TColumn column) -> either (const Nothing) Just (toVector @Bool @VU.Vector column) _ -> Nothing -- | The target column as a label vector, or 'Nothing' on failure. -interpretLabelCol :: forall a. (Columnable a) => DataFrame -> T.Text -> Maybe (V.Vector a) +interpretLabelCol :: + forall a. (Columnable a) => DataFrame -> T.Text -> Maybe (V.Vector a) interpretLabelCol df target = case interpret @a df (Col target) of Right (TColumn column) -> either (const Nothing) Just (toVector @a column) _ -> Nothing -- | Predict the label for a single row by walking a fixed tree (@True@ → left). -predictWithTree :: forall a. (Columnable a) => T.Text -> DataFrame -> Int -> Tree a -> a +predictWithTree :: + forall a. (Columnable a) => T.Text -> DataFrame -> Int -> Tree a -> a predictWithTree _ _ _ (Leaf v) = v predictWithTree target df idx (Branch cond left right) = predictWithTree @a target df idx (childFor cond left right idx df) @@ -61,12 +74,16 @@ childFor cond left right idx df = case branchBool df cond of Nothing -> left Just boolVals -> if boolVals VU.! idx then left else right -predictManyWithTree :: forall a. (Columnable a) => Tree a -> DataFrame -> V.Vector Int -> V.Vector a +predictManyWithTree :: + forall a. (Columnable a) => Tree a -> DataFrame -> V.Vector Int -> V.Vector a predictManyWithTree = predictManyWithTreeCached @a M.empty --- | 'predictManyWithTree' resolving each branch condition through a 'CondCache'. --- Each condition is read at most once per call rather than once per row. -predictManyWithTreeCached :: forall a. (Columnable a) => CondCache -> Tree a -> DataFrame -> V.Vector Int -> V.Vector a +{- | 'predictManyWithTree' resolving each branch condition through a 'CondCache'. +Each condition is read at most once per call rather than once per row. +-} +predictManyWithTreeCached :: + forall a. + (Columnable a) => CondCache -> Tree a -> DataFrame -> V.Vector Int -> V.Vector a predictManyWithTreeCached cache tree df indices = V.create $ do mv <- VM.new (V.length indices) fill mv (V.zip (V.enumFromN 0 (V.length indices)) indices) tree @@ -78,15 +95,33 @@ predictManyWithTreeCached cache tree df indices = V.create $ do Nothing -> fill mv prs left Just boolVals -> fillSplit mv (V.partition (\(_, i) -> boolVals VU.! i) prs) left right - fillSplit :: VM.MVector s a -> (V.Vector (Int, Int), V.Vector (Int, Int)) -> Tree a -> Tree a -> ST s () + fillSplit :: + VM.MVector s a -> + (V.Vector (Int, Int), V.Vector (Int, Int)) -> + Tree a -> + Tree a -> + ST s () fillSplit mv (leftPrs, rightPrs) left right = fill mv leftPrs left >> fill mv rightPrs right -identifyCarePoints :: forall a. (Columnable a) => T.Text -> DataFrame -> V.Vector Int -> Tree a -> Tree a -> [CarePoint] +identifyCarePoints :: + forall a. + (Columnable a) => + T.Text -> DataFrame -> V.Vector Int -> Tree a -> Tree a -> [CarePoint] identifyCarePoints = identifyCarePointsCached @a M.empty --- | Rows the parent must route to a specific child for the (fixed) subtrees to --- classify correctly; a 'CondCache' avoids re-interpreting subtree conditions. -identifyCarePointsCached :: forall a. (Columnable a) => CondCache -> T.Text -> DataFrame -> V.Vector Int -> Tree a -> Tree a -> [CarePoint] +{- | Rows the parent must route to a specific child for the (fixed) subtrees to +classify correctly; a 'CondCache' avoids re-interpreting subtree conditions. +-} +identifyCarePointsCached :: + forall a. + (Columnable a) => + CondCache -> + T.Text -> + DataFrame -> + V.Vector Int -> + Tree a -> + Tree a -> + [CarePoint] identifyCarePointsCached cache target df indices leftTree rightTree = maybe [] carePoints (interpretLabelCol @a df target) where @@ -94,7 +129,9 @@ identifyCarePointsCached cache target df indices leftTree rightTree = rightPreds = predictManyWithTreeCached cache rightTree df indices carePoints targetVals = V.toList (V.imapMaybe (checkPoint targetVals leftPreds rightPreds) indices) -checkPoint :: (Eq a) => V.Vector a -> V.Vector a -> V.Vector a -> Int -> Int -> Maybe CarePoint +checkPoint :: + (Eq a) => + V.Vector a -> V.Vector a -> V.Vector a -> Int -> Int -> Maybe CarePoint checkPoint targetVals leftPreds rightPreds k idx = case (leftPreds V.! k == trueLabel, rightPreds V.! k == trueLabel) of (True, False) -> Just (CarePoint idx GoLeft) @@ -108,12 +145,19 @@ countCarePointErrors :: Expr Bool -> DataFrame -> [CarePoint] -> Int countCarePointErrors cond df carePoints = maybe (length carePoints) (`countErrorsByVec` carePoints) (branchBool df cond) -partitionIndices :: Expr Bool -> DataFrame -> V.Vector Int -> (V.Vector Int, V.Vector Int) +partitionIndices :: + Expr Bool -> DataFrame -> V.Vector Int -> (V.Vector Int, V.Vector Int) partitionIndices = partitionIndicesCached M.empty --- | 'partitionIndices' resolving the condition through a 'CondCache'; a miss --- routes every index left (matching the uncached fallback). -partitionIndicesCached :: CondCache -> Expr Bool -> DataFrame -> V.Vector Int -> (V.Vector Int, V.Vector Int) +{- | 'partitionIndices' resolving the condition through a 'CondCache'; a miss +routes every index left (matching the uncached fallback). +-} +partitionIndicesCached :: + CondCache -> + Expr Bool -> + DataFrame -> + V.Vector Int -> + (V.Vector Int, V.Vector Int) partitionIndicesCached cache cond df indices = case lookupCondVec cache df cond of Nothing -> (indices, V.empty) Just boolVals -> V.partition (boolVals VU.!) indices @@ -125,7 +169,8 @@ isValidAtNode cfg df indices c = where (t, f) = partitionIndices c df indices -majorityValueFromIndices :: forall a. (Columnable a, Ord a) => T.Text -> DataFrame -> V.Vector Int -> a +majorityValueFromIndices :: + forall a. (Columnable a, Ord a) => T.Text -> DataFrame -> V.Vector Int -> a majorityValueFromIndices target df indices = majorityOf (countLabels (labelColOrThrow @a df target) indices) labelColOrThrow :: forall a. (Columnable a) => DataFrame -> T.Text -> V.Vector a @@ -141,21 +186,31 @@ majorityOf counts | M.null counts = error "Empty indices in majorityValueFromIndices" | otherwise = fst (maximumBy (compare `on` snd) (M.toList counts)) -computeTreeLoss :: forall a. (Columnable a) => T.Text -> DataFrame -> V.Vector Int -> Tree a -> Double +computeTreeLoss :: + forall a. + (Columnable a) => T.Text -> DataFrame -> V.Vector Int -> Tree a -> Double computeTreeLoss = computeTreeLossCached @a M.empty -- | 0/1 loss of a tree over @indices@, with a 'CondCache' for the predictions. -computeTreeLossCached :: forall a. (Columnable a) => CondCache -> T.Text -> DataFrame -> V.Vector Int -> Tree a -> Double +computeTreeLossCached :: + forall a. + (Columnable a) => + CondCache -> T.Text -> DataFrame -> V.Vector Int -> Tree a -> Double computeTreeLossCached cache target df indices tree | V.null indices = 0 - | otherwise = maybe 1.0 (treeLoss cache tree df indices) (interpretLabelCol @a df target) + | otherwise = + maybe 1.0 (treeLoss cache tree df indices) (interpretLabelCol @a df target) -treeLoss :: (Columnable a) => CondCache -> Tree a -> DataFrame -> V.Vector Int -> V.Vector a -> Double +treeLoss :: + (Columnable a) => + CondCache -> Tree a -> DataFrame -> V.Vector Int -> V.Vector a -> Double treeLoss cache tree df indices targetVals = - fromIntegral (countMismatches targetVals indices preds) / fromIntegral (V.length indices) + fromIntegral (countMismatches targetVals indices preds) + / fromIntegral (V.length indices) where preds = predictManyWithTreeCached cache tree df indices countMismatches :: (Eq a) => V.Vector a -> V.Vector Int -> V.Vector a -> Int countMismatches targetVals indices preds = - V.length (V.ifilter (\k _ -> targetVals V.! (indices V.! k) /= preds V.! k) preds) + V.length + (V.ifilter (\k _ -> targetVals V.! (indices V.! k) /= preds V.! k) preds) diff --git a/dataframe-learn/src/DataFrame/DecisionTree/Prune.hs b/dataframe-learn/src/DataFrame/DecisionTree/Prune.hs index a090aad..3372462 100644 --- a/dataframe-learn/src/DataFrame/DecisionTree/Prune.hs +++ b/dataframe-learn/src/DataFrame/DecisionTree/Prune.hs @@ -2,9 +2,10 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE ScopedTypeVariables #-} --- | Post-convergence simplification of a fitted tree and its expression form: --- drop branches forced by path-condition entailment, collapse identical --- siblings, and fold redundant nested conditionals. +{- | Post-convergence simplification of a fitted tree and its expression form: +drop branches forced by path-condition entailment, collapse identical +siblings, and fold redundant nested conditionals. +-} module DataFrame.DecisionTree.Prune ( pruneDead, treeEq, @@ -16,9 +17,10 @@ import DataFrame.Internal.Column (Columnable) import DataFrame.Internal.Expression (Expr (..), eqExpr) import DataFrame.Internal.Simplify (PredFact, entails, factFalse, factTrue) --- | Drop branches whose test is forced by the path conditions reaching them, --- and collapse @Branch c t t@ to @t@. Sound for the decidable threshold subset; --- other tests are left untouched. +{- | Drop branches whose test is forced by the path conditions reaching them, +and collapse @Branch c t t@ to @t@. Sound for the decidable threshold subset; +other tests are left untouched. +-} pruneDead :: forall a. (Columnable a) => Tree a -> Tree a pruneDead = go [] where @@ -27,7 +29,11 @@ pruneDead = go [] go facts (Branch cond left right) = case entails facts cond of Just True -> go facts left Just False -> go facts right - Nothing -> reconcile cond (go (addFact (factTrue cond) facts) left) (go (addFact (factFalse cond) facts) right) + Nothing -> + reconcile + cond + (go (addFact (factTrue cond) facts) left) + (go (addFact (factFalse cond) facts) right) reconcile :: (Columnable a) => Expr Bool -> Tree a -> Tree a -> Tree a reconcile cond left right @@ -43,8 +49,9 @@ treeEq (Leaf x) (Leaf y) = x == y treeEq (Branch c1 l1 r1) (Branch c2 l2 r2) = eqExpr c1 c2 && treeEq l1 l2 && treeEq r1 r2 treeEq _ _ = False --- | Recursively fold @If@ expressions whose branches coincide or nest the same --- condition; leave other expressions structurally unchanged. +{- | Recursively fold @If@ expressions whose branches coincide or nest the same +condition; leave other expressions structurally unchanged. +-} pruneExpr :: forall a. (Columnable a) => Expr a -> Expr a pruneExpr (If cond t0 f0) = collapseIf cond (pruneExpr t0) (pruneExpr f0) pruneExpr (Unary op e) = Unary op (pruneExpr e) diff --git a/dataframe-learn/src/DataFrame/DecisionTree/Tao.hs b/dataframe-learn/src/DataFrame/DecisionTree/Tao.hs index da693f0..f12a6cd 100644 --- a/dataframe-learn/src/DataFrame/DecisionTree/Tao.hs +++ b/dataframe-learn/src/DataFrame/DecisionTree/Tao.hs @@ -3,9 +3,10 @@ {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} --- | Tree Alternating Optimization: hold the tree fixed and re-optimize one node --- at a time, bottom-up, minimizing care-point misroutes. Sibling subtrees at a --- depth level are independent and optimized in parallel. +{- | Tree Alternating Optimization: hold the tree fixed and re-optimize one node +at a time, bottom-up, minimizing care-point misroutes. Sibling subtrees at a +depth level are independent and optimized in parallel. +-} module DataFrame.DecisionTree.Tao ( taoOptimize, taoOptimizeCV, @@ -17,7 +18,11 @@ module DataFrame.DecisionTree.Tao ( import DataFrame.DecisionTree.CondVec import DataFrame.DecisionTree.Linear (bestLinearCandidate) -import DataFrame.DecisionTree.Pool (bestDiscreteCandidate, candidateParChunk, evalWithPenaltyVec) +import DataFrame.DecisionTree.Pool ( + bestDiscreteCandidate, + candidateParChunk, + evalWithPenaltyVec, + ) import DataFrame.DecisionTree.Predict import DataFrame.DecisionTree.Prune (pruneDead) import DataFrame.DecisionTree.Types @@ -34,8 +39,9 @@ import qualified Data.Text as T import qualified Data.Vector as V import qualified Data.Vector.Unboxed as VU --- | The constant per-fit context threaded through the node-optimization --- recursion (the cache is rebuilt each iteration). +{- | The constant per-fit context threaded through the node-optimization +recursion (the cache is rebuilt each iteration). +-} data TaoEnv = TaoEnv { teCache :: !CondCache , teCfg :: !TreeConfig @@ -45,13 +51,32 @@ data TaoEnv = TaoEnv } -- | Public TAO entry point over raw conditions; materializes each once. -taoOptimize :: forall a. (Columnable a, Ord a) => TreeConfig -> T.Text -> [Expr Bool] -> DataFrame -> V.Vector Int -> Tree a -> Tree a +taoOptimize :: + forall a. + (Columnable a, Ord a) => + TreeConfig -> + T.Text -> + [Expr Bool] -> + DataFrame -> + V.Vector Int -> + Tree a -> + Tree a taoOptimize cfg target conds df = taoOptimizeCV @a cfg target (mapMaybe (materializeCondVec df) conds) df --- | TAO outer loop over pre-evaluated candidates: iterate until the iteration --- budget or convergence tolerance is reached, then prune dead branches. -taoOptimizeCV :: forall a. (Columnable a, Ord a) => TreeConfig -> T.Text -> [CondVec] -> DataFrame -> V.Vector Int -> Tree a -> Tree a +{- | TAO outer loop over pre-evaluated candidates: iterate until the iteration +budget or convergence tolerance is reached, then prune dead branches. +-} +taoOptimizeCV :: + forall a. + (Columnable a, Ord a) => + TreeConfig -> + T.Text -> + [CondVec] -> + DataFrame -> + V.Vector Int -> + Tree a -> + Tree a taoOptimizeCV cfg target condVecs df rootIndices initialTree = go 0 initialTree (lossWith baseCache initialTree) where @@ -67,32 +92,63 @@ taoOptimizeCV cfg target condVecs df rootIndices initialTree = newLoss = lossWith cache tree' -- | Public single-iteration entry point. -taoIteration :: forall a. (Columnable a, Ord a) => TreeConfig -> T.Text -> [Expr Bool] -> DataFrame -> V.Vector Int -> Tree a -> Tree a +taoIteration :: + forall a. + (Columnable a, Ord a) => + TreeConfig -> + T.Text -> + [Expr Bool] -> + DataFrame -> + V.Vector Int -> + Tree a -> + Tree a taoIteration cfg target conds df rootIndices tree = let condVecs = mapMaybe (materializeCondVec df) conds cache = addTreeCondsToCache df tree (condCacheFromVecs condVecs) in taoIterationCV @a cache cfg target condVecs df rootIndices tree -- | One bottom-to-top sweep: re-optimize every node level by level. -taoIterationCV :: forall a. (Columnable a, Ord a) => CondCache -> TreeConfig -> T.Text -> [CondVec] -> DataFrame -> V.Vector Int -> Tree a -> Tree a +taoIterationCV :: + forall a. + (Columnable a, Ord a) => + CondCache -> + TreeConfig -> + T.Text -> + [CondVec] -> + DataFrame -> + V.Vector Int -> + Tree a -> + Tree a taoIterationCV cache cfg target condVecs df rootIndices tree = - foldl' (optimizeDepthLevel env rootIndices) tree [treeDepth tree, treeDepth tree - 1 .. 0] + foldl' + (optimizeDepthLevel env rootIndices) + tree + [treeDepth tree, treeDepth tree - 1 .. 0] where env = TaoEnv cache cfg target condVecs df -optimizeDepthLevel :: forall a. (Columnable a, Ord a) => TaoEnv -> V.Vector Int -> Tree a -> Int -> Tree a +optimizeDepthLevel :: + forall a. + (Columnable a, Ord a) => TaoEnv -> V.Vector Int -> Tree a -> Int -> Tree a optimizeDepthLevel env rootIndices tree = optimizeAtDepth @a env rootIndices tree 0 -optimizeAtDepth :: forall a. (Columnable a, Ord a) => TaoEnv -> V.Vector Int -> Tree a -> Int -> Int -> Tree a +optimizeAtDepth :: + forall a. + (Columnable a, Ord a) => + TaoEnv -> V.Vector Int -> Tree a -> Int -> Int -> Tree a optimizeAtDepth env indices tree currentDepth targetDepth | currentDepth == targetDepth = optimizeNode @a env indices tree | otherwise = case tree of Leaf v -> Leaf v Branch cond left right -> optimizeChildren @a env indices cond left right currentDepth targetDepth --- | Optimize the two subtrees over their disjoint index sets, scoring the left --- in parallel with the right (the cache is read-only, so this is pure). -optimizeChildren :: forall a. (Columnable a, Ord a) => TaoEnv -> V.Vector Int -> Expr Bool -> Tree a -> Tree a -> Int -> Int -> Tree a +{- | Optimize the two subtrees over their disjoint index sets, scoring the left +in parallel with the right (the cache is read-only, so this is pure). +-} +optimizeChildren :: + forall a. + (Columnable a, Ord a) => + TaoEnv -> V.Vector Int -> Expr Bool -> Tree a -> Tree a -> Int -> Int -> Tree a optimizeChildren env indices cond left right currentDepth targetDepth = forceTreeWork left' `par` (forceTreeWork right' `pseq` Branch cond left' right') where @@ -100,15 +156,18 @@ optimizeChildren env indices cond left right currentDepth targetDepth = left' = optimizeAtDepth @a env indicesL left (currentDepth + 1) targetDepth right' = optimizeAtDepth @a env indicesR right (currentDepth + 1) targetDepth --- | Force a subtree's optimization work to WHNF so the parallel scheduler has --- something substantial to evaluate; pure and value-preserving. +{- | Force a subtree's optimization work to WHNF so the parallel scheduler has +something substantial to evaluate; pure and value-preserving. +-} forceTreeWork :: Tree a -> () forceTreeWork (Leaf v) = v `seq` () forceTreeWork (Branch c l r) = c `seq` forceTreeWork l `seq` forceTreeWork r --- | Re-optimize one node: pick its best split, or collapse to a leaf when the --- node is empty or the chosen split underflows 'minLeafSize'. -optimizeNode :: forall a. (Columnable a, Ord a) => TaoEnv -> V.Vector Int -> Tree a -> Tree a +{- | Re-optimize one node: pick its best split, or collapse to a leaf when the +node is empty or the chosen split underflows 'minLeafSize'. +-} +optimizeNode :: + forall a. (Columnable a, Ord a) => TaoEnv -> V.Vector Int -> Tree a -> Tree a optimizeNode env indices tree | V.null indices = tree | otherwise = case tree of @@ -117,7 +176,10 @@ optimizeNode env indices tree where leaf = Leaf (majorityValueFromIndices @a (teTarget env) (teDf env) indices) -rebuiltBranch :: forall a. (Columnable a, Ord a) => TaoEnv -> V.Vector Int -> Expr Bool -> Tree a -> Tree a -> Tree a -> Tree a +rebuiltBranch :: + forall a. + (Columnable a, Ord a) => + TaoEnv -> V.Vector Int -> Expr Bool -> Tree a -> Tree a -> Tree a -> Tree a rebuiltBranch env indices oldCond left right leaf | underflows = leaf | otherwise = Branch newCond left right @@ -126,43 +188,79 @@ rebuiltBranch env indices oldCond left right leaf (l, r) = partitionIndicesCached (teCache env) newCond (teDf env) indices underflows = V.length l < minLeafSize (teCfg env) || V.length r < minLeafSize (teCfg env) --- | The lowest-penalty replacement condition for a node, falling back to the --- current condition when no valid candidate beats it. -findBestSplitTAO :: forall a. (Columnable a) => TaoEnv -> V.Vector Int -> Tree a -> Tree a -> Expr Bool -> Expr Bool +{- | The lowest-penalty replacement condition for a node, falling back to the +current condition when no valid candidate beats it. +-} +findBestSplitTAO :: + forall a. + (Columnable a) => + TaoEnv -> V.Vector Int -> Tree a -> Tree a -> Expr Bool -> Expr Bool findBestSplitTAO env indices leftTree rightTree currentCond | V.null indices || null carePoints = currentCond - | pureReplacementLinear cfg, Just c <- linearCandidate, isValidAtNode cfg (teDf env) indices c = c + | pureReplacementLinear cfg + , Just c <- linearCandidate + , isValidAtNode cfg (teDf env) indices c = + c | otherwise = bestOfPool penaltyCV currentCond pool where cfg = teCfg env - carePoints = identifyCarePointsCached @a (teCache env) (teTarget env) (teDf env) indices leftTree rightTree + carePoints = + identifyCarePointsCached @a + (teCache env) + (teTarget env) + (teDf env) + indices + leftTree + rightTree penaltyCV = evalWithPenaltyVec cfg carePoints linearCandidate = bestLinearCandidate cfg (teDf env) carePoints valid = filterValidCandidates cfg indices (teConds env) - pool = candidatePool env indices currentCond (bestDiscreteCandidate cfg penaltyCV valid) linearCandidate + pool = + candidatePool + env + indices + currentCond + (bestDiscreteCandidate cfg penaltyCV valid) + linearCandidate bestOfPool :: (CondVec -> (Int, Int)) -> Expr Bool -> [CondVec] -> Expr Bool bestOfPool _ currentCond [] = currentCond bestOfPool penaltyCV _ pool = cvExpr (minimumBy (compare `on` penaltyCV) pool) --- | Validity-filtered candidates the node could split on: both children must --- keep at least 'minLeafSize'. Scored in parallel chunks, order preserved. +{- | Validity-filtered candidates the node could split on: both children must +keep at least 'minLeafSize'. Scored in parallel chunks, order preserved. +-} filterValidCandidates :: TreeConfig -> V.Vector Int -> [CondVec] -> [CondVec] filterValidCandidates cfg indices condVecs = map snd (filter fst (zip validity condVecs)) where - validity = map (validAtNode cfg indices) condVecs `using` parListChunk candidateParChunk rdeepseq + validity = + map (validAtNode cfg indices) condVecs + `using` parListChunk candidateParChunk rdeepseq validAtNode :: TreeConfig -> V.Vector Int -> CondVec -> Bool validAtNode cfg indices cv = nTrue >= minLeaf && (V.length indices - nTrue) >= minLeaf where minLeaf = minLeafSize cfg - nTrue = V.foldl' (\ !acc i -> if cvVec cv VU.! i then acc + 1 else acc) (0 :: Int) indices + nTrue = + V.foldl' + (\ !acc i -> if cvVec cv VU.! i then acc + 1 else acc) + (0 :: Int) + indices --- | The candidate pool to minimize over: the current condition, the best --- discrete candidate, and the linear candidate, each kept only if valid. -candidatePool :: TaoEnv -> V.Vector Int -> Expr Bool -> Maybe CondVec -> Maybe (Expr Bool) -> [CondVec] +{- | The candidate pool to minimize over: the current condition, the best +discrete candidate, and the linear candidate, each kept only if valid. +-} +candidatePool :: + TaoEnv -> + V.Vector Int -> + Expr Bool -> + Maybe CondVec -> + Maybe (Expr Bool) -> + [CondVec] candidatePool env indices currentCond discreteCV linearCandidate = - filter (isValidAtNode (teCfg env) (teDf env) indices . cvExpr) (catMaybes [currentCV, discreteCV, linearCV]) + filter + (isValidAtNode (teCfg env) (teDf env) indices . cvExpr) + (catMaybes [currentCV, discreteCV, linearCV]) where currentCV = CondVec currentCond <$> lookupCondVec (teCache env) (teDf env) currentCond linearCV = linearCandidate >>= materializeCondVec (teDf env) diff --git a/dataframe-learn/src/DataFrame/DecisionTree/Types.hs b/dataframe-learn/src/DataFrame/DecisionTree/Types.hs index 349e21d..b97a1b6 100644 --- a/dataframe-learn/src/DataFrame/DecisionTree/Types.hs +++ b/dataframe-learn/src/DataFrame/DecisionTree/Types.hs @@ -5,8 +5,9 @@ {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} --- | Shared types, configuration and ordering machinery for the decision-tree --- learner. Imported by every other @DataFrame.DecisionTree.*@ module. +{- | Shared types, configuration and ordering machinery for the decision-tree +learner. Imported by every other @DataFrame.DecisionTree.*@ module. +-} module DataFrame.DecisionTree.Types ( Tree (..), treeDepth, @@ -38,8 +39,9 @@ import System.Environment (lookupEnv) import System.IO.Unsafe (unsafePerformIO) import Type.Reflection (SomeTypeRep (..), typeRep) --- | A fitted tree: a leaf value, or an internal node testing a boolean --- expression with @True@ routing left. +{- | A fitted tree: a leaf value, or an internal node testing a boolean +expression with @True@ routing left. +-} data Tree a = Leaf !a | Branch !(Expr Bool) !(Tree a) !(Tree a) @@ -49,8 +51,9 @@ treeDepth :: Tree a -> Int treeDepth (Leaf _) = 0 treeDepth (Branch _ l r) = 1 + max (treeDepth l) (treeDepth r) --- | A row the parent node must route to a specific child for the subtrees to --- classify it correctly (the TAO objective is the count of misroutes). +{- | A row the parent node must route to a specific child for the subtrees to +classify it correctly (the TAO objective is the count of misroutes). +-} data CarePoint = CarePoint { cpIndex :: !Int , cpCorrectDir :: !Direction @@ -121,8 +124,9 @@ defaultTreeConfig = , pureReplacementLinear = False } --- | Which column types support ordering for splits. Register a type with --- 'orderable' and combine with @<>@. +{- | Which column types support ordering for splits. Register a type with +'orderable' and combine with @<>@. +-} newtype ColumnOrdering = ColumnOrdering (M.Map SomeTypeRep OrdDict) instance Semigroup ColumnOrdering where @@ -164,8 +168,9 @@ otherOrderings = data OrdDict where OrdDict :: (Columnable a, Ord a) => Proxy a -> OrdDict --- | Run @k@ with the @Ord a@ instance recovered from the ordering registry, --- or 'Nothing' when @a@ is not registered. +{- | Run @k@ with the @Ord a@ instance recovered from the ordering registry, +or 'Nothing' when @a@ is not registered. +-} withOrdFrom :: forall a r. (Columnable a) => ColumnOrdering -> ((Ord a) => r) -> Maybe r withOrdFrom (ColumnOrdering m) k = case M.lookup (SomeTypeRep (typeRep @a)) m of diff --git a/dataframe-parquet/src/DataFrame/IO/Parquet.hs b/dataframe-parquet/src/DataFrame/IO/Parquet.hs index bcdc983..b02593f 100644 --- a/dataframe-parquet/src/DataFrame/IO/Parquet.hs +++ b/dataframe-parquet/src/DataFrame/IO/Parquet.hs @@ -402,7 +402,7 @@ getNonNullableColumn totalRows description chunks = go decoder = foldNonNullable totalRows $ (\(vs, _, _) -> vs) - <$> Stream.unfoldEach (readPages description decoder) (Stream.fromList chunks) + <$> Stream.unfoldMany (readPages description decoder) (Stream.fromList chunks) unboxedGo :: forall a. @@ -412,7 +412,7 @@ getNonNullableColumn totalRows description chunks = unboxedGo decoder = foldNonNullableUnboxed totalRows $ (\(vs, _, _) -> vs) - <$> Stream.unfoldEach + <$> Stream.unfoldMany (readPages description decoder) (Stream.fromList chunks) @@ -449,7 +449,7 @@ getNullableColumn totalRows description chunks = go decoder = foldNullable maxDef totalRows $ (\(vs, ds, _) -> (vs, ds)) - <$> Stream.unfoldEach (readPages description decoder) (Stream.fromList chunks) + <$> Stream.unfoldMany (readPages description decoder) (Stream.fromList chunks) unboxedGo :: forall a. (Columnable a, VU.Unbox a) => @@ -458,7 +458,7 @@ getNullableColumn totalRows description chunks = unboxedGo decoder = foldNullableUnboxed maxDef totalRows $ (\(vs, ds, _) -> (vs, ds)) - <$> Stream.unfoldEach + <$> Stream.unfoldMany (readPages description decoder) (Stream.fromList chunks) @@ -499,7 +499,7 @@ getRepeatedColumn description chunks = m Column go decoder = foldRepeated maxRep maxDef $ - Stream.unfoldEach (readPages description decoder) (Stream.fromList chunks) + Stream.unfoldMany (readPages description decoder) (Stream.fromList chunks) unboxedGo :: forall a. @@ -513,7 +513,7 @@ getRepeatedColumn description chunks = m Column unboxedGo decoder = foldRepeatedUnboxed maxRep maxDef $ - Stream.unfoldEach + Stream.unfoldMany (readPages description decoder) (Stream.fromList chunks) diff --git a/examples/examples.cabal b/examples/examples.cabal index a24a941..9a37173 100644 --- a/examples/examples.cabal +++ b/examples/examples.cabal @@ -139,7 +139,7 @@ executable examples cassava >= 0.1 && < 1, containers >= 0.6.7 && < 0.9, directory >= 1.3.0.0 && < 2, - granite ^>= 0.6, + granite >= 0.6 && < 1, hashable >= 1.2 && < 2, hasktorch, http-conduit, diff --git a/flake.lock b/flake.lock new file mode 100644 index 0000000..96f37ec --- /dev/null +++ b/flake.lock @@ -0,0 +1,61 @@ +{ + "nodes": { + "flake-utils": { + "inputs": { + "systems": "systems" + }, + "locked": { + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1780243769, + "narHash": "sha256-x5UQuRsH3MqI0U9afaXSNqzTPSeZlRLvFAav2Ux1pNw=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "331800de5053fcebacf6813adb5db9c9dca22a0c", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixos-unstable", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "flake-utils": "flake-utils", + "nixpkgs": "nixpkgs" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix index b5e3f8e..b1f06b8 100644 --- a/flake.nix +++ b/flake.nix @@ -14,14 +14,42 @@ repo = "granite"; owner = "mchav"; rev = "main"; - hash = "sha256-Z/o8gxMOBltKiaL0NEjMUyOvUljRvKErWeM6Ul3GM9k="; + hash = "sha256-jmmI2+kbqe+X/CDP986qQnUMGR35iNW5deNLovHpBHA="; + }; + pinchPkg = pkgs.fetchFromGitHub { + repo = "pinch"; + owner = "abhinav"; + rev = "v0.5.2.0"; + hash = "sha256-kuCS4EePc4aIONCvF0sOZt4pCazAq1z9+a/AY9b7Q6c="; + }; + networkRunPkg = pkgs.fetchFromGitHub { + repo = "network-run"; + owner = "kazu-yamamoto"; + rev = "v0.3.1"; + hash = "sha256-xyyf+Le2x9ACJBE4ua7wWHsfOQHNi7D+DksghZFh35I="; }; hsPkgs = pkgs.haskellPackages.extend (self: super: { granite = self.callCabal2nix "granite" granitePkg { }; + network-run = self.callCabal2nix "network-run" networkRunPkg { }; + pinch = self.callCabal2nix "pinch" pinchPkg { }; + dataframe-arrow = self.callCabal2nix "dataframe-arrow" ./dataframe-arrow { }; + dataframe-core = self.callCabal2nix "dataframe-core" ./dataframe-core { }; + dataframe-csv = self.callCabal2nix "dataframe-csv" ./dataframe-csv { }; + dataframe-csv-th = self.callCabal2nix "dataframe-csv-th" ./dataframe-csv-th { }; dataframe-fastcsv = self.callCabal2nix "dataframe-fastcsv" ./dataframe-fastcsv { }; - dataframe-persistent = self.callCabal2nix "dataframe-persistent" ./dataframe-persistent { }; + # dataframe-fusion = self.callCabal2nix "dataframe-fusion" ./dataframe-fusion { }; dataframe-hasktorch = self.callCabal2nix "dataframe-hasktorch" ./dataframe-hasktorch { }; + dataframe-json = self.callCabal2nix "dataframe-json" ./dataframe-json { }; + dataframe-lazy = self.callCabal2nix "dataframe-lazy" ./dataframe-lazy { }; + dataframe-learn = self.callCabal2nix "dataframe-learn" ./dataframe-learn { }; + dataframe-operations = self.callCabal2nix "dataframe-operations" ./dataframe-operations { }; + dataframe-parquet = self.callCabal2nix "dataframe-parquet" ./dataframe-parquet { }; + dataframe-parquet-th = self.callCabal2nix "dataframe-parquet-th" ./dataframe-parquet-th { }; + dataframe-parsing = self.callCabal2nix "dataframe-parsing" ./dataframe-parsing { }; + dataframe-persistent = self.callCabal2nix "dataframe-persistent" ./dataframe-persistent { }; + dataframe-th = self.callCabal2nix "dataframe-th" ./dataframe-th { }; + dataframe-viz = self.callCabal2nix "dataframe-viz" ./dataframe-viz { }; dataframe = self.callCabal2nix "dataframe" ./. { }; }); in @@ -29,17 +57,45 @@ packages = { default = hsPkgs.dataframe; dataframe = hsPkgs.dataframe; + dataframe-arrow = hsPkgs.dataframe-arrow; + dataframe-core = hsPkgs.dataframe-core; + dataframe-csv = hsPkgs.dataframe-csv; + dataframe-csv-th = hsPkgs.dataframe-csv-th; dataframe-fastcsv = hsPkgs.dataframe-fastcsv; + # dataframe-fusion = hsPkgs.dataframe-fusion; dataframe-hasktorch = hsPkgs.dataframe-hasktorch; + dataframe-json = hsPkgs.dataframe-json; + dataframe-lazy = hsPkgs.dataframe-lazy; + dataframe-learn = hsPkgs.dataframe-learn; + dataframe-operations = hsPkgs.dataframe-operations; + dataframe-parquet = hsPkgs.dataframe-parquet; + dataframe-parquet-th = hsPkgs.dataframe-parquet-th; + dataframe-parsing = hsPkgs.dataframe-parsing; dataframe-persistent = hsPkgs.dataframe-persistent; + dataframe-th = hsPkgs.dataframe-th; + dataframe-viz = hsPkgs.dataframe-viz; }; devShells.default = hsPkgs.shellFor { packages = ps: [ ps.dataframe + ps.dataframe-arrow + ps.dataframe-core + ps.dataframe-csv + ps.dataframe-csv-th ps.dataframe-fastcsv - ps.dataframe-persistent + # ps.dataframe-fusion ps.dataframe-hasktorch + ps.dataframe-json + ps.dataframe-lazy + ps.dataframe-learn + ps.dataframe-operations + ps.dataframe-parquet + ps.dataframe-parquet-th + ps.dataframe-parsing + ps.dataframe-persistent + ps.dataframe-th + ps.dataframe-viz ]; nativeBuildInputs = with hsPkgs; [ ghc