module Data.DecisionTree where

import Data.List
import Data.Ord (comparing)

-- An attribute is anything that divides the dataset in two
-- 'test' is the dividing predicate; 'label' is only used when showing the attribute
data Att a = Att {test :: a -> Bool, label :: String}

-- The predicate cannot be shown, so only the label is printed
instance Show (Att a) where
  show att = "Att(" ++ label att ++ ")"

-- A decision tree
-- A Decision stores the attribute to test, a value of type b for the node itself,
-- the branch taken when the test is True and the branch taken when it is False
data DTree a b
  = Result b
  | Decision (Att a) b (DTree a b) (DTree a b)
  deriving (Show)

-- Run the decision tree on an example
-- Follows the True/False branches until a Result is reached; the values stored
-- at the intermediate Decision nodes are ignored
decide :: DTree a b -> a -> b
decide (Result b) _ = b
decide (Decision att _ tbranch fbranch) a =
  if test att a then decide tbranch a else decide fbranch a

-- Maps over every stored value, both at the leaves and at the decision nodes
instance Functor (DTree a) where
  fmap f (Result b) = Result (f b)
  fmap f (Decision att b tbranch fbranch) =
    Decision att (f b) (fmap f tbranch) (fmap f fbranch)

-- A splitter either declines to split a node (Nothing) or returns the attribute
-- to test together with the states of the True and False children
type Splitter a b = b -> Maybe (Att a, b, b)

-- Repeatedly split nodes to form a DTree
-- We leave the state information at the branches so that the tree can be intelligently pruned
runSplitter :: Splitter a b -> b -> DTree a b
runSplitter split = run
  where
    run node = case split node of
      Nothing -> Result node
      Just (att, b1, b2) -> Decision att node (run b1) (run b2)

-- This isn't actually used yet
-- Partitions the list with the attribute and hands the (True, False) halves to 'cont'
splitOn :: Att a -> ([a] -> [a] -> b) -> [a] -> b
splitOn att cont as = cont tlist flist
  where
    (tlist, flist) = partition (test att) as

-- points [1,2,3] = [(1,[2,3]),(2,[1,3]),(3,[1,2])]
-- Used because there is no Eq instance for Att so we can't delete them from lists
points :: [a] -> [(a, [a])]
points [] = []
points (a : as) = (a, as) : [(b, a : bs) | (b, bs) <- points as]

-- Split a node with the attribute which minimises valf
-- A node is a pair of (attributes still available, examples). No split is made
-- when either list is empty, or when every attribute would send all examples to
-- the same side. The chosen attribute is removed from the list given to both children.
minSplit :: Ord o => ([a] -> [a] -> o) -> Splitter a ([Att a], [a])
minSplit _ ([], _) = Nothing
minSplit _ (_, []) = Nothing
minSplit valf (atts, as)
  | null scored = Nothing
  | otherwise = Just (best, (rest, tlist), (rest, flist))
  where
    -- Every attribute paired with the remaining attributes and its partition,
    -- keeping only those that really divide the examples
    candidates =
      [ (att, others, ts, fs)
      | (att, others) <- points atts
      , let (ts, fs) = partition (test att) as
      , not (null ts || null fs)
      ]
    -- Attach the score once per candidate so valf is not re-run on every
    -- comparison; the score stays lazy, so a lone candidate is never scored
    scored = [(valf ts fs, c) | c@(_, _, ts, fs) <- candidates]
    -- minimumBy keeps the earliest candidate among equal scores
    (_, (best, rest, tlist, flist)) = minimumBy (comparing fst) scored

-- Prune the tree after 'i' decisions
-- At most 'i' decisions are kept along any path from the root: a Decision
-- reached once the budget is used up (i <= 0) is replaced by a Result holding
-- that node's own stored value. Result leaves are returned unchanged.
maxDecisions :: Int -> DTree a b -> DTree a b
maxDecisions i (Decision att b tbranch fbranch)
  | i <= 0 = Result b
  | otherwise = Decision att b (descend tbranch) (descend fbranch)
  where
    descend = maxDecisions (i - 1)
maxDecisions _ r = r

-- Prune decisions by predicate
-- Every Decision in the tree is examined, not just the root: a node whose stored
-- value satisfies 'f' collapses to a Result with that value, otherwise the node
-- is kept and both of its branches are pruned in turn.
prune :: (b -> Bool) -> DTree a b -> DTree a b
prune _ r@(Result _) = r
prune f (Decision att b tbranch fbranch)
  | f b = Result b
  | otherwise = Decision att b (prune f tbranch) (prune f fbranch)

-- Useful as a value function for minsplit
-- Base-2 entropy of the list: the values are sorted and grouped, and each group
-- contributes -p * log2 p where p is the group's share of the whole list.
-- Note this takes a single list, whereas minSplit's value function receives the
-- two halves of a partition, so a caller has to combine the two sides itself.
entropy :: (Ord a, Floating b) => [a] -> b
entropy as = negate $ sum $ map value $ group $ sort as
  where
    -- Size of the whole list, computed once rather than once per group
    total = flength as
    value g = let p = flength g / total in p * logBase 2 p
    flength xs = fromIntegral (length xs)