-- Copyright (c) 2010, Adam Crume
-- All rights reserved.
--
-- Redistribution and use in source and binary forms, with or without
-- modification, are permitted provided that the following conditions are met:
--
--     * Redistributions of source code must retain the above copyright notice,
--       this list of conditions and the following disclaimer.
--     * Redistributions in binary form must reproduce the above copyright
--       notice, this list of conditions and the following disclaimer in the
--       documentation and/or other materials provided with the distribution.
--     * Neither the name of the University of California nor the names of its
--       contributors may be used to endorse or promote products derived from
--       this software without specific prior written permission.
--
-- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-- AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-- IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
-- ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
-- LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
-- CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
-- SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
-- INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
-- CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
-- ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
-- POSSIBILITY OF SUCH DAMAGE.


{-# LANGUAGE ScopedTypeVariables #-}

module Ein.Main (
                 ein,
                 eval,
                 eval',
                 initialContext,
                 initialContextWithCore
                ) where


import Control.Monad.State
import Data.List
import qualified Data.Map as Map
import qualified Data.Set as Set
import Ein.Data
import Ein.Parse
import Ein.Pattern
import qualified Ein.StandardSymbols as S
import Text.Parsec.String


-- | Returns the "head-most" symbol.  For example, the base symbol of f[1][2] is f.
--   Recurses into the first argument of a Condition, so the base symbol of
--   @Condition[f[x_], test]@ is @f@.
--
--   An expression with no symbol at its head (a literal such as @1@ or
--   @"s"@, or an application whose innermost head is such a literal) is its
--   own base.  This lets callers use the result as a lookup key without
--   crashing on inputs like @1 = 2@.  A @Condition@ with no arguments has
--   base symbol @Condition@.
baseSymbol :: Exp -> Exp
baseSymbol (EFunc (ESym "Condition") (x:_)) = baseSymbol x
baseSymbol (EFunc f _) = baseSymbol f
baseSymbol e = e


-- | Pure attribute test: does the expression carry the named attribute in
--   the given context?  Expressions without an entry in 'contextAtts' carry
--   no attributes at all.
hasAtt' :: String -- ^ Attribute to look up
        -> Context -- ^ Context to look up the mapping in
        -> Exp -- ^ Expression which may have the attribute
        -> Bool -- ^ True if the expression has the attribute
hasAtt' a c e = maybe False (Set.member (ESym a)) (Map.lookup e (contextAtts c))


-- | Monadic form of 'hasAtt'': answers the same question using the context
--   held in the 'EC' state, without modifying it.
hasAtt :: String -- ^ Attribute to look up
       -> Exp -- ^ Expression which may have the attribute
       -> EC Bool -- ^ True if the expression has the attribute
hasAtt a e = gets (\ctx -> hasAtt' a ctx e)


-- | Runs 'eval' outside the state monad.
--
--   Arguments: the expression to evaluate, the starting context, and extra
--   rules that 'evalWithRules' tries right after the context's eager rules
--   at every evaluation step.  Returns the evaluated expression paired with
--   the final context.
eval' :: Exp -> Context -> [Rule] -> (Exp, Context)
eval' expr ctx extraRules = flip runState ctx $ eval expr extraRules

-- | Evaluates an expression until no rule changes it, inside the 'EC' state.
--
--   * Integer and string literals are returned unchanged.
--   * A symbol is rewritten once by 'evalWithRules'.  If that changed it,
--     the result is evaluated again.
--   * For a function application, the head is evaluated first.  The
--     evaluated head's @HoldFirst@ / @HoldRest@ attributes then decide
--     whether the first argument and the remaining arguments are evaluated
--     or passed through untouched.  The rebuilt application is rewritten
--     once by 'evalWithRules' and, if the result differs from the input
--     expression, evaluated again.
--
--   The extra rules @rs@ are passed down to every step, including recursive
--   evaluations.  The symbol and application cases bracket their work with
--   'contextPush' / 'contextPop'; those are defined outside this module
--   (presumably an evaluation trace — confirm in Ein.Data).
--   NOTE(review): termination depends on the rule set reaching a fixed
--   point; a rule that always yields a different expression loops forever.
eval :: Exp -> [Rule] -> EC Exp
eval (EInt x) rs = return (EInt x)
eval (EStr x) rs = return (EStr x)
eval x@(ESym _) rs =
    do
      contextPush x
      x' <- evalWithRules x rs
      -- A rule fired: keep rewriting until a fixed point is reached.
      result <- if x == x' then return x else eval x' rs
      contextPop
      return result
eval e@(EFunc f args) rs =
    do
      contextPush e
      head' <- eval f rs
      -- Hold attributes are looked up on the evaluated head, not on f.
      holdFirst <- hasAtt "HoldFirst" head'
      holdRest <- hasAtt "HoldRest" head'
      args' <- case args of
                 [] -> return []
                 (x:xs) -> do
                             x' <- if holdFirst then return x else eval x rs
                             xs' <- if holdRest then return xs else mapM (\y -> eval y rs) xs
                             return (x':xs')
      e' <- evalWithRules (EFunc head' args') rs
      -- Compared against the original input e, not the rebuilt
      -- EFunc head' args'.  If only the head/arguments changed and no rule
      -- fired, the rebuilt expression is evaluated once more before the
      -- fixed point is detected.
      result <- if e == e' then return e else eval e' rs
      contextPop
      return result


-- evalWithRules :: Exp -> [Rule] -> EC (Exp)
-- evalWithRules e rs =
--     do
--       c <- get
--       rules <- return $ contextLazyRules c
--       base <- return $ Map.lookup (baseSymbol e) (contextValues c)
--       rules <- return $ case base of
--                           Nothing -> rules
--                           Just x -> x ++ rules
--       rules <- return $ (contextEagerRules c) ++ rs ++ rules
--       result <- return $ foldl (foldExp c) Nothing rules
--       result <- case result of
--                   Nothing -> return e
--                   Just (x', c') -> do
--                               put c'
--                               return x'
--       return result
--     where
--       foldExp :: Context -> Maybe (Exp, Context) -> Rule -> Maybe (Exp, Context)
--       foldExp c x r = case x of
--                         Nothing -> case runState (
--                                                   do
--                                                     contextPush $ EStr $ "Rule: " ++ ruleString r
--                                                     c <- get
--                                                     val :: Maybe (Exp, Context) <- return $ (ruleFunc r) e c
--                                                     put $ case val of
--                                                             Nothing -> c
--                                                             Just (_, c') -> c'
--                                                     contextPop
--                                                     return $ case val of
--                                                                Nothing -> Nothing
--                                                                Just x -> Just $ fst x
--                                                  ) c of
--                                      (Nothing, _) -> Nothing
--                                      (Just x, c) -> Just (x, c)
--                         Just _ -> x


-- | Rewrites @e@ once using the first rule that applies, or returns @e@
--   unchanged when none does.
--
--   Candidate rules are tried in this order:
--
--   1. the context's eager rules;
--   2. the caller-supplied rules @rs@;
--   3. the rules stored in 'contextValues' under the expression's
--      'baseSymbol';
--   4. the context's lazy rules.
--
--   Every attempt starts from the context as it was on entry.  Around the
--   rule call, a @"Rule: <name>"@ entry is pushed with 'contextPush' and
--   popped again with 'contextPop'.  Only the winning rule's resulting
--   context is kept; attempts that return 'Nothing' leave the state as it
--   was.
evalWithRules :: Exp -> [Rule] -> EC (Exp)
evalWithRules e rs =
    do
      ctx <- get
      let symbolRules = maybe [] id (Map.lookup (baseSymbol e) (contextValues ctx))
          candidates = contextEagerRules ctx ++ rs ++ symbolRules ++ contextLazyRules ctx
      case firstMatch ctx candidates of
        Nothing -> return e
        Just (rewritten, ctx') -> put ctx' >> return rewritten
    where
      -- Try the candidates left to right and stop at the first one that fires.
      firstMatch :: Context -> [Rule] -> Maybe (Exp, Context)
      firstMatch _ [] = Nothing
      firstMatch start (r:more) =
          case attempt start r of
            Nothing -> firstMatch start more
            found -> found

      -- Run a single rule against the starting context, keeping the
      -- context it produces only when it fires.
      attempt :: Context -> Rule -> Maybe (Exp, Context)
      attempt start r =
          case runState (traced r) start of
            (Nothing, _) -> Nothing
            (Just x, after) -> Just (x, after)

      -- The rule sees the context with its own trace entry pushed; the
      -- entry is popped again whether or not the rule fired.
      traced r =
          do
            contextPush $ EStr $ "Rule: " ++ ruleString r
            pushed <- get
            let outcome :: Maybe (Exp, Context) = (ruleFunc r) e pushed
            case outcome of
              Nothing -> return ()
              Just (_, c') -> put c'
            contextPop
            return $! case outcome of
                        Nothing -> Nothing
                        Just hit -> Just $! fst hit


------------------------- Context initialization

-- | The context every evaluation starts from.
--
--   Positionally, the fields are:
--
--   * an empty map;
--   * the built-in rules keyed by the head symbol they rewrite;
--   * the list @[flattenRule]@;
--   * the list @[powerRule, sequenceRule, functionRule]@;
--   * @0@;
--   * @[]@.
--
--   NOTE(review): field meanings follow the constructor order in Ein.Data.
--   Presumably these are atts, values, eager rules, lazy rules, ... —
--   confirm there.
initialContext :: Context
initialContext =
    Context Map.empty
            builtinRules
            [flattenRule]
            [powerRule, sequenceRule, functionRule]
            0
            []
  where
    -- One entry per built-in head symbol; each value lists the rules for
    -- that head in the order they should be tried.
    builtinRules = Map.fromList
        [ (S.compound,       [compoundRule])
        , (S.setAttributes,  [setAttributesRule])
        , (S.set,            [setRule])
        , (S.setDelayed,     [setRule])
        , (S.plus,           [addRule, hsCaseAddRule, hsCaseAddRule2])
        , (S.times,          [timesRule])
        , (S.greater,        [intComparisonRule (>) "Greater"])
        , (S.less,           [intComparisonRule (<) "Less"])
        , (S.greaterOrEqual, [intComparisonRule (>=) "GreaterOrEqual"])
        , (S.lessOrEqual,    [intComparisonRule (<=) "LessOrEqual"])
        , (S.map,            [mapRule])
        , (S.equal,          [equalRule, hsCaseEqualRule])
        , (S.matchQ,         [matchQRule])
        , (S.count,          [countRule])
        ]


-- | @Compound[e1, ..., en]@: evaluates each expression in turn, threading
--   the context from one to the next.  Yields the value of the last one, or
--   'S.null' when there are none.  The actual work is done by 'evalList'.
compoundRule' :: Exp -> Context -> Maybe (Exp, Context)
compoundRule' expr ctx =
    case expr of
      EFunc (ESym "Compound") stmts -> Just (evalList ctx stmts)
      _ -> Nothing

compoundRule :: Rule
compoundRule = Rule "<compoundRule>" compoundRule'


-- | Evaluates a list of expressions left to right.  Each expression runs
--   (via 'eval'' with no extra rules) in the context left behind by the
--   previous one.  Returns the last expression's value with the final
--   context.  An empty list yields 'S.null' and the context unchanged.
--   The values of all but the last expression are discarded.
evalList :: Context -> [Exp] -> (Exp, Context)
evalList ctx exprs =
    case exprs of
      [] -> (S.null, ctx)
      [final] -> eval' final ctx []
      (first:rest) -> evalList (snd (eval' first ctx [])) rest


-- | @SetAttributes[sym, attr]@ or @SetAttributes[sym, {attr1, ...}]@:
--   records the attribute(s) for @sym@ in the context and evaluates to
--   'S.null'.  A @List@ in the second position is unpacked via 'addAtts';
--   any other expression is added as a single attribute via 'addAtt'.
setAttributesRule' :: Exp -> Context -> Maybe (Exp, Context)
setAttributesRule' (EFunc (ESym "SetAttributes") [sym, atts]) ctx =
    Just $ case atts of
             EFunc (ESym "List") names -> (S.null, addAtts sym names ctx)
             single -> (S.null, addAtt sym single ctx)
setAttributesRule' _ _ = Nothing

setAttributesRule :: Rule
setAttributesRule = Rule "<setAttributesRule>" setAttributesRule'


-- | @Set[lhs, rhs]@ and @SetDelayed[lhs, rhs]@: store a rewrite rule for
--   @lhs@ under its 'baseSymbol' (via 'addVal') and evaluate to 'S.null'.
--   @Set@ stores @Rule[lhs, rhs]@; @SetDelayed@ stores
--   @RuleDelayed[lhs, rhs]@.
setRule' :: Exp -> Context -> Maybe (Exp, Context)
setRule' (EFunc op [lhs, rhs]) ctx
    | op == S.set = store S.rule
    | op == S.setDelayed = store S.ruleDelayed
  where
    store ruleHead =
        Just (S.null, addVal (baseSymbol lhs) (expressionRule (EFunc ruleHead [lhs, rhs])) ctx)
setRule' _ _ = Nothing

setRule :: Rule
setRule = Rule "<setRule>" setRule'


-- | Folds the integer literals among the arguments of @Plus@ into one sum.
--
--   When at least one argument is an integer literal, the new argument list
--   is that sum followed by the remaining arguments in their original order.
--   For example, @Plus[a, 1, b, 2]@ becomes @Plus[3, a, b]@.
--
--   @Plus[]@ is @0@, and a single remaining term replaces the whole sum.
--   The rule declines ('Nothing') when the argument list would not change,
--   so evaluation can reach a fixed point.
addRule' :: Exp -> Context -> Maybe (Exp, Context)
addRule' (EFunc (ESym "Plus") terms) ctx =
    case folded of
      [] -> Just (EInt 0, ctx)
      [single] -> Just (single, ctx)
      _ | folded /= terms -> Just (EFunc S.plus folded, ctx)
        | otherwise -> Nothing
  where
    literals = [n | EInt n <- terms]
    others = filter (not . isLiteral) terms
    isLiteral (EInt _) = True
    isLiteral _ = False
    folded
      | null literals = others
      | otherwise = EInt (sum literals) : others
addRule' _ _ = Nothing

addRule :: Rule
addRule = Rule "<addRule>" addRule'


-- | Folds the integer literals among the arguments of @Times@ into one
--   product.
--
--   When at least one argument is an integer literal, the new argument list
--   is that product followed by the remaining arguments in their original
--   order.  For example, @Times[a, 2, b, 3]@ becomes @Times[6, a, b]@.
--
--   @Times[]@ is @1@, and a single remaining factor replaces the whole
--   product.  The rule declines ('Nothing') when the argument list would not
--   change.
timesRule' :: Exp -> Context -> Maybe (Exp, Context)
timesRule' (EFunc (ESym "Times") factors) ctx =
    case folded of
      [] -> Just (EInt 1, ctx)
      [single] -> Just (single, ctx)
      _ | folded /= factors -> Just (EFunc S.times folded, ctx)
        | otherwise -> Nothing
  where
    literals = [n | EInt n <- factors]
    others = filter (not . isLiteral) factors
    isLiteral (EInt _) = True
    isLiteral _ = False
    folded
      | null literals = others
      | otherwise = EInt (product literals) : others
timesRule' _ _ = Nothing

timesRule :: Rule
timesRule = Rule "<timesRule>" timesRule'


-- | Evaluates @Power@ on integer literals.  @Power@ groups to the right:
--   @Power[x1, x2, ..., xn]@ means @x1^(x2^(...^xn))@, so @Power[2, 3, 2]@
--   is @2^9 = 512@.
--
--   Folding proceeds from the right.  An integer base is combined with what
--   follows only when everything after it has already collapsed to a single
--   integer exponent that is non-negative.  Negative exponents would need
--   rationals and are left alone.  This keeps @Power[2, 3, a]@ unevaluated
--   (it is @2^(3^a)@, not @8^a@) while @Power[a, 2, 3]@ becomes
--   @Power[a, 8]@.  Negative integer bases are allowed.
--
--   A single remaining term replaces the whole expression.  The rule
--   declines ('Nothing') when the argument list would not change.
powerRule' :: Exp -> Context -> Maybe (Exp, Context)
powerRule' (EFunc (ESym "Power") args) c =
    case foldTower args of
      [x] -> Just (x, c)
      xs | xs /= args -> Just (EFunc S.power xs, c)
      _ -> Nothing
  where
    foldTower [] = []
    foldTower [x] = [x]
    foldTower (x:xs) =
        case (x, foldTower xs) of
          -- Only a fully collapsed tail is a genuine exponent for x.
          (EInt base, [EInt ex]) | ex >= 0 -> [EInt (base ^ ex)]
          (y, ys) -> y : ys
powerRule' _ _ = Nothing

powerRule :: Rule
powerRule = Rule "<powerRule>" powerRule'


-- | For heads with the @Flat@ attribute, splices nested applications of the
--   same head into the outer one: @f[a, f[b, c]]@ becomes @f[a, b, c]@.
--   Only one level is spliced per application.  The rule declines when the
--   head is not @Flat@ or nothing would change.
flattenRule' :: Exp -> Context -> Maybe (Exp, Context)
flattenRule' (EFunc op operands) ctx
    | hasAtt' "Flat" ctx op, spliced /= operands = Just (EFunc op spliced, ctx)
  where
    spliced = operands >>= inline
    inline (EFunc inner innerArgs) | inner == op = innerArgs
    inline other = [other]
flattenRule' _ _ = Nothing

flattenRule :: Rule
flattenRule = Rule "<flattenRule>" flattenRule'


-- | Splices every @Sequence[...]@ argument into the enclosing application:
--   @f[a, Sequence[b, c], d, Sequence[e]]@ becomes @f[a, b, c, d, e]@.
--
--   All Sequence arguments are spliced in a single pass.  Previously only
--   the first one was spliced, and each further one cost another full
--   evaluation round.  The contents of a Sequence are not searched again.
--   The rule declines for heads with the @SequenceHold@ attribute and when
--   no argument is a Sequence.
sequenceRule' :: Exp -> Context -> Maybe (Exp, Context)
sequenceRule' (EFunc op operands) ctx
    | hasAtt' "SequenceHold" ctx op = Nothing
    | spliced == operands = Nothing
    | otherwise = Just (EFunc op spliced, ctx)
  where
    spliced = concatMap splice operands
    splice (EFunc (ESym "Sequence") items) = items
    splice other = [other]
sequenceRule' _ _ = Nothing

sequenceRule :: Rule
sequenceRule = Rule "<sequenceRule>" sequenceRule'


-- | @Map[f, h[x1, ..., xn]]@ becomes @h[f[x1], ..., f[xn]]@.  The head @h@
--   is kept.  A second argument that is not an application leaves the
--   expression unevaluated.
mapRule' :: Exp -> Context -> Maybe (Exp, Context)
mapRule' (EFunc (ESym "Map") [fn, EFunc container items]) ctx =
    Just (EFunc container [EFunc fn [item] | item <- items], ctx)
mapRule' _ _ = Nothing

mapRule :: Rule
mapRule = Rule "<mapRule>" mapRule'


-- | Evaluates a chained integer comparison such as @Less[1, 2, 3]@, using
--   the comparison @op@ for the head symbol named @name@.
--
--   * Zero or one argument (of any kind) gives 'S.true'.
--   * Each adjacent pair is checked left to right while both are integer
--     literals.  The first pair that fails gives 'S.false', even if
--     non-integer arguments follow it.
--   * 'S.true' requires every argument to be an integer literal and every
--     adjacent pair to satisfy @op@.
--   * Anything else leaves the expression unevaluated.
intComparisonRule' :: (Integer -> Integer -> Bool) -> String -> Exp -> Context -> Maybe (Exp, Context)
intComparisonRule' op name (EFunc hd operands) c
    | hd == ESym name = fmap (\holds -> (if holds then S.true else S.false, c)) (chain operands)
  where
    chain (EInt a : rest@(EInt b : _)) =
        case (op a b, chain rest) of
          (True, Just True) -> Just True
          (False, _) -> Just False
          (_, Just False) -> Just False
          _ -> Nothing
    chain [] = Just True
    chain [_] = Just True
    chain _ = Nothing
intComparisonRule' _ _ _ _ = Nothing

intComparisonRule :: (Integer -> Integer -> Bool) -> String -> Rule
intComparisonRule op name = Rule ("<" ++ name ++ ">") (intComparisonRule' op name)


-- | @Equal[x1, ..., xn]@ compares adjacent arguments.
--
--   * Integer literals and string literals compare by value.
--   * Two symbols are equal only when identical; otherwise they give no
--     verdict.
--   * Structurally identical expressions are equal; any other pair gives
--     no verdict.
--
--   The result is 'S.true' when every adjacent pair is equal, and 'S.false'
--   as soon as any pair is known to differ.  Otherwise the expression is
--   left unevaluated.  Fewer than two arguments gives 'S.true'.
equalRule' :: Exp -> Context -> Maybe (Exp, Context)
equalRule' (EFunc (ESym "Equal") operands) ctx =
    fmap (\same -> (if same then S.true else S.false, ctx)) verdict
  where
    verdict = foldr combine (Just True) (zipWith decide operands (drop 1 operands))
    combine here rest =
        case (here, rest) of
          (Just True, Just True) -> Just True
          (Just False, _) -> Just False
          (_, Just False) -> Just False
          _ -> Nothing
    decide (EInt a) (EInt b) = Just (a == b)
    decide (EStr a) (EStr b) = Just (a == b)
    decide (ESym a) (ESym b)
        | a == b = Just True
        | otherwise = Nothing
    decide a b | a == b = Just True
    decide _ _ = Nothing
equalRule' _ _ = Nothing

equalRule :: Rule
equalRule = Rule "<equalRule>" equalRule'


-- | @MatchQ[expr, pattern]@ gives 'S.true' when 'match' returns at least one
--   result for the pattern against the expression, and 'S.false' otherwise.
matchQRule' :: Exp -> Context -> Maybe (Exp, Context)
matchQRule' (EFunc (ESym "MatchQ") [subject, pat]) ctx
    | null (match pat ctx subject) = Just (S.false, ctx)
    | otherwise = Just (S.true, ctx)
matchQRule' _ _ = Nothing

matchQRule :: Rule
matchQRule = Rule "<matchQRule>" matchQRule'


-- | @Count[expr, pattern]@ and @Count[expr, pattern, levelspec]@ give the
--   number of subexpressions of @expr@ that match @pattern@ at the requested
--   levels.
--
--   Level 0 is @expr@ itself, level 1 its arguments, and so on.  Heads are
--   not visited.  Supported level specifications:
--
--   * @n@ (n > 0)      — levels 1 through n
--   * @0@              — level 0 only
--   * @Infinity@       — levels 1 and deeper
--   * @{n}@            — level n only
--   * @{m, n}@         — levels m through n
--   * @{m, Infinity}@  — levels m and deeper
--
--   The two-argument form uses @{1, 1}@.  Any other level specification
--   leaves the expression unevaluated.
countRule' :: Exp -> Context -> Maybe (Exp, Context)
countRule' (EFunc (ESym "Count") (e:p:[])) c =
    countRule' (EFunc S.count [e, p, list [EInt 1, EInt 1]]) c
countRule' (EFunc (ESym "Count") (e:p:levels:[])) c =
    let
        -- Count matches in x.  toMin is the number of levels still to
        -- descend before matches count (<= 0 once the minimum level has
        -- been reached).  toMax is the number of levels that may still be
        -- descended (Nothing = unbounded); a negative value means x is
        -- below the maximum level.
        countFrom x toMin toMax =
            case toMax of
              Just n | n < 0 -> 0
              _ -> here + below
            where
              -- Matches above the minimum level must not be counted.
              here = if toMin <= 0 && not (null (match p c x)) then 1 else 0
              below =
                  case x of
                    EFunc _ args -> sum [countFrom a (toMin - 1) (fmap (subtract 1) toMax) | a <- args]
                    _ -> 0
        tally lo hi = Just (EInt (countFrom e lo hi), c)
    in
      case levels of
        EInt 0 -> tally 0 (Just 0)
        EInt n | n > 0 -> tally 1 (Just n)
        ESym "Infinity" -> tally 1 Nothing
        EFunc (ESym "List") [EInt n] -> tally n (Just n)
        EFunc (ESym "List") [EInt lo, EInt hi] -> tally lo (Just hi)
        EFunc (ESym "List") [EInt lo, ESym "Infinity"] -> tally lo Nothing
        _ -> Nothing
countRule' _ _ = Nothing

countRule :: Rule
countRule = Rule "<countRule>" countRule'


-- TODO: I'm sure this isn't quite right.  May need to make sure arguments don't get evaluated multiple times, support renaming of arguments, etc.
-- | Applies a pure function:
--   @Function[{v1, ..., vn}, body][a1, ..., am]@ with @n <= m@.
--   Each variable @vi@ is bound to @[ai]@ and the bindings are handed to
--   'genRep' together with @body@ (presumably substituting the variables
--   in the body — see Ein.Pattern).  Surplus arguments are dropped.
--   Fewer arguments than variables leaves the expression unevaluated.
functionRule' :: Exp -> Context -> Maybe (Exp, Context)
functionRule' (EFunc (EFunc (ESym "Function") [EFunc (ESym "List") params, body]) actuals) ctx
    | length params <= length actuals =
        let bindings = Map.fromList [(param, [actual]) | (param, actual) <- zip params actuals]
        in Just (genRep bindings body, ctx)
functionRule' _ _ = Nothing

functionRule :: Rule
functionRule = Rule "<functionRule>" functionRule'


-- | Pushes the addend to the left of an @HsCase@ into every branch:
--   @Plus[x, HsCase[e, {l1 -> r1, ...}], rest...]@ becomes
--   @Plus[HsCase[e, {l1 -> Plus[x, r1], ...}], rest...]@.
--
--   Each branch must be a two-argument @Rule@ or @RuleDelayed@.  If any
--   branch is not, the rule declines instead of failing with a
--   pattern-match error.
hsCaseAddRule' :: Exp -> Context -> Maybe (Exp, Context)
hsCaseAddRule' (EFunc (ESym "Plus") (x : EFunc (ESym "HsCase") [e, EFunc (ESym "List") rules] : xs)) c =
    do
      rules' <- mapM fixRule rules
      let case' = EFunc (ESym "HsCase") [e, EFunc S.list rules']
      return (EFunc S.plus (case' : xs), c)
    where
      -- Add x in front of one branch's right-hand side.
      fixRule (EFunc h [left, right])
          | h == S.rule || h == S.ruleDelayed =
              Just (EFunc h [left, EFunc S.plus [x, right]])
      fixRule _ = Nothing
hsCaseAddRule' _ _ = Nothing

hsCaseAddRule :: Rule
hsCaseAddRule = Rule "<hsCaseAddRule>" hsCaseAddRule'


-- | Pushes the addend to the right of a leading @HsCase@ into every branch:
--   @Plus[HsCase[e, {l1 -> r1, ...}], x, rest...]@ becomes
--   @Plus[HsCase[e, {l1 -> Plus[r1, x], ...}], rest...]@.
--
--   Each branch must be a two-argument @Rule@ or @RuleDelayed@.  If any
--   branch is not, the rule declines instead of failing with a
--   pattern-match error.
hsCaseAddRule2' :: Exp -> Context -> Maybe (Exp, Context)
hsCaseAddRule2' (EFunc (ESym "Plus") (EFunc (ESym "HsCase") [e, EFunc (ESym "List") rules] : x : xs)) c =
    do
      rules' <- mapM fixRule rules
      let case' = EFunc (ESym "HsCase") [e, EFunc S.list rules']
      return (EFunc S.plus (case' : xs), c)
    where
      -- Add x after one branch's right-hand side.
      fixRule (EFunc h [left, right])
          | h == S.rule || h == S.ruleDelayed =
              Just (EFunc h [left, EFunc S.plus [right, x]])
      fixRule _ = Nothing
hsCaseAddRule2' _ _ = Nothing

hsCaseAddRule2 :: Rule
hsCaseAddRule2 = Rule "<hsCaseAddRule2>" hsCaseAddRule2'


-- | Splits an equation against an @HsCase@ into one equation per branch:
--   @Equal[f[..., e, ...], HsCase[e, {l1 -> r1, ...}]]@ becomes
--   @Sequence[Equal[f[..., l1', ...], r1], ...]@.
--
--   Each @li'@ is the branch's left-hand side with its pattern removed.
--   A blank pattern @Pattern[v, Blank[]]@ becomes @v@, and an integer
--   literal stays as it is.  Every argument of @f@ equal to @e@ is replaced.
--
--   Applies only when @e@ is one of @f@'s arguments.  If a branch is not a
--   two-argument @Rule@/@RuleDelayed@, or its left-hand side is not one of
--   the supported patterns, the rule declines instead of failing with a
--   pattern-match error.
hsCaseEqualRule' :: Exp -> Context -> Maybe (Exp, Context)
hsCaseEqualRule' (EFunc (ESym "Equal") [EFunc f args, EFunc (ESym "HsCase") [e, EFunc (ESym "List") rules]]) c
    | e `elem` args =
        do
          eqs <- mapM ruleToEq rules
          return (EFunc S.sequence eqs, c)
    where
      unpattern (EFunc (ESym "Pattern") [x, EFunc (ESym "Blank") []]) = Just x
      unpattern (EInt x) = Just (EInt x)
      unpattern _ = Nothing
      ruleToEq (EFunc h [left, right])
          | h == S.rule || h == S.ruleDelayed =
              do
                lhs <- unpattern left
                let fixArg x = if x == e then lhs else x
                return (EFunc S.equal [EFunc f (map fixArg args), right])
      ruleToEq _ = Nothing
hsCaseEqualRule' _ _ = Nothing

hsCaseEqualRule :: Rule
hsCaseEqualRule = Rule "<hsCaseEqualRule>" hsCaseEqualRule'


------------------------- End context initialization


-- | Loads the core library.  Parses @Ein/core.ein@ (a path relative to the
--   working directory) and evaluates it on top of 'initialContext'.
--   Returns the resulting context; the value of the core file is discarded.
--   A parse error aborts with 'error'.
initialContextWithCore :: IO Context
initialContextWithCore =
    do
      parsed <- parseFromFile input "Ein/core.ein"
      either (error . show) loadCore parsed
  where
    loadCore core = return $ snd $ eval' core initialContext []


-- | Evaluates the source string @s@ and prints the result.  Evaluation runs
--   in the context obtained by evaluating @Ein/core.ein@ on top of
--   'initialContext'.  If the core file fails to parse, the parse error is
--   printed instead.  @s@ itself is parsed with 'parse''.
ein :: String -> IO ()
ein s = parseFromFile input "Ein/core.ein" >>= either print run
  where
    run core =
        do
          (_, withCore) <- return $ eval' core initialContext []
          (answer, _) <- return $ eval' (parse' s) withCore []
          print answer


-- | Ad-hoc smoke test.  Prints, in order:
--
--   * the expanded form of a sample definition;
--   * an expression built through the 'Num' instance;
--   * the parsed core library and its evaluated value;
--   * the attributes recorded after loading the core library;
--   * the result of each entry in @samples@.
--
--   Each sample is evaluated separately against the core context.
--   The unused attribute map / context that used to be built here has been
--   removed.
main :: IO ()
main =
    do
      print $ showExpanded $ parse' "LessOrEqual[a_, b_]/;Less[a,b]:=True"
      print ((1+2+3+4)::Exp)
      result <- parseFromFile input "Ein/core.ein"
      case result of
        Left err -> print err
        Right e -> do
                    print e
                    (e', c') <- return $ eval' e initialContext []
                    print "-------------------"
                    print e'
                    print $ contextAtts c'
                    --print $ contextValues c'
                    mapM_ (showSample c') samples
  where
    -- Every sample starts from the same core context, so assignments made
    -- by one sample (e.g. x=1) are not visible to the next.
    samples =
      [ "x=1;x+2"
      , "x=2; 1+x+y+2"
      , "Less[1,2]"
      , "Less[2,1]"
      , "x=2; Less[1+x+y+2, 10]"
      , "Power[2,3]"
      , "Foo[a,Sequence[b,c],d]"
      , "SetAttributes[Bar, SequenceHold]; Bar[a,Sequence[b,c],d]"
      ]
    showSample ctx src =
        do
          (value, _) <- return $ eval' (parse' src) ctx []
          print "-------------------"
          print value
