-- Copyright (c) 2010, Adam Crume
-- All rights reserved.
--
-- Redistribution and use in source and binary forms, with or without
-- modification, are permitted provided that the following conditions are met:
--
--     * Redistributions of source code must retain the above copyright notice,
--       this list of conditions and the following disclaimer.
--     * Redistributions in binary form must reproduce the above copyright
--       notice, this list of conditions and the following disclaimer in the
--       documentation and/or other materials provided with the distribution.
--     * Neither the name of the University of California nor the names of its
--       contributors may be used to endorse or promote products derived from
--       this software without specific prior written permission.
--
-- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-- AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-- IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
-- ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
-- LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
-- CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
-- SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
-- INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
-- CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
-- ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
-- POSSIBILITY OF SUCH DAMAGE.


module RecurrenceGenerator (
                            analyze
                           ) where

import Bag
import Control.Monad
import Data.List
import qualified Data.Map as Map
import Data.Ratio
import qualified Data.Set as Set
import DataCon
import DepAnalysis
import DynFlags
import Ein.Data
import Ein.Main
import Ein.Pattern
import qualified Ein.StandardSymbols as S
import GHC
import GHC.Paths
import LinearMath
import Outputable
import Name
import System
import Var


-- | Renders any 'Outputable' value as a 'String' by pretty-printing it with
--   'ppr' and flattening the resulting document with 'showSDoc'.  A helper like
--   this arguably belongs in the Outputable module itself; it is defined here in
--   case the GHC version in use does not provide one.
showPpr :: Outputable a => a -> String
showPpr value = showSDoc (ppr value)


-- | Extracts the recurrence equations of one named function from a binding.
--
--   * An 'AbsBinds' wrapper is searched recursively through its inner bindings.
--   * A 'FunBind' contributes equations only when its printed name equals the
--     requested function name.
--   * Every other kind of binding contributes nothing.
handleBind :: OutputableBndr id => String -- ^ Name of the function
           -> HsBind id -- ^ Some binding
           -> [Exp] -- ^ Recurrence equations
handleBind f bind =
    case bind of
      AbsBinds _ _ _ inner ->
          concatMap (handleBind f . unLoc) (bagToList inner)
      FunBind fid _ matches _ _ _
          | showPpr fid == f -> handleMatchGroup f matches
          | otherwise -> []
      _ -> []


-- | Collects the recurrence equations produced by every match (equation) in a
--   match group, in source order.
handleMatchGroup :: OutputableBndr id => String -- ^ Name of the function
                 -> MatchGroup id -- ^ Match group
                 -> [Exp] -- ^ Recurrence equations
handleMatchGroup f (MatchGroup matches _) =
    concat [handleMatch f match | L _ match <- matches]


-- | Builds the single recurrence equation for one match of a function.
--
--   The left side is the function applied to its argument patterns (converted
--   with 'patToExp'); the right side is the number of patterns plus the time
--   expression of the match body.
handleMatch :: OutputableBndr id => String -- ^ Name of the function
            -> Match id -- ^ Match
            -> [Exp] -- ^ Recurrence equations
handleMatch f (Match pats _ grhss) =
    [EFunc S.equal [application, cost]]
    where
      -- The function applied to the size expressions of its patterns.
      application = EFunc (ESym f) [patToExp p | L _ p <- pats]
      -- One unit of cost per pattern, plus the cost of the body.
      patternCost = EInt (toInteger (length pats))
      cost = EFunc (S.plus) [patternCost, handleGRHSs grhss]


-- | Converts a GRHSs to a time expression.
--
--   Only the first guarded right-hand side is analyzed; any further guarded
--   alternatives and the local bindings are currently ignored.  An empty list
--   of right-hand sides is reported with an explicit error instead of an
--   opaque list-index failure.
handleGRHSs :: OutputableBndr id => GRHSs id -> Exp
handleGRHSs (GRHSs (rhs:_) _) = handleLGRHS rhs -- We should handle cases with multiple guards
handleGRHSs (GRHSs [] _) = error "handleGRHSs: no right-hand sides to analyze"


-- | Produces the time expression of a located guarded right-hand side.
--   The guard statements are ignored; only the body expression is costed.
handleLGRHS :: OutputableBndr id => LGRHS id -> Exp
handleLGRHS located =
    case unLoc located of
      GRHS _ body -> lhsExprToTimeExp body -- We should handle guards


-- | Translates a located Haskell expression into an equivalent symbolic
--   expression (its value, not its running time).
--
--   Supported forms: variables, overloaded literals, function application
--   (nested applications are flattened into one argument list), the binary
--   operators @+@, @-@, @*@, @/@ and list cons, parentheses, and @case@.
--   Cons is translated to one plus the translation of its tail operand, and a
--   @case@ becomes an @HsCase@ node holding a list of rules.  Anything else
--   raises an error.
lhsExprToExp :: OutputableBndr id => LHsExpr id -> Exp
lhsExprToExp (L _ expr) =
    case expr of
      HsVar v -> ESym (showPpr v)
      HsOverLit lit -> overLitToExp lit
      HsApp fun arg ->
          let arg' = lhsExprToExp arg
          in case lhsExprToExp fun of
               -- Already an application: append the new argument.
               EFunc h args -> EFunc h $ args ++ [arg']
               h -> EFunc h [arg']
      OpApp x op _ y ->
          let
              x' = lhsExprToExp x
              y' = lhsExprToExp y
          in
            case showPpr op of -- TODO: Find a better way to match on the op
              "(+)" -> EFunc S.plus [x', y']
              "(-)" -> EFunc S.plus [x', EFunc S.times [EInt (-1), y']]
              "(*)" -> EFunc S.times [x', y']
              "(/)" -> EFunc S.times [x', EFunc S.power [y', EInt (-1)]]
              "(GHC.Types.:)" -> EFunc S.plus [EInt 1, y']
              z -> error $ "lhsExprToExp: Unrecognized op: " ++ z
      HsPar inner -> lhsExprToExp inner
      HsCase scrutinee (MatchGroup alts _) ->
          let
              -- Only the first guarded right-hand side is used, and its guards are ignored.
              bodyOf (GRHSs guarded _) =
                  case guarded !! 0 of
                    L _ (GRHS _ body) -> lhsExprToExp body
              -- Only the first pattern of each alternative is used.
              toRule (L _ (Match pats _ grhss)) =
                  EFunc S.rule [patToExp' $ unLoc $ pats !! 0, bodyOf grhss]
          in
            EFunc (ESym "HsCase") [lhsExprToExp scrutinee, EFunc S.list (map toRule alts)]
      other -> error $ "No lhsExprToExp for " ++ (showSDocDebug $ ppr other)


-- | Translates a located Haskell expression into an expression for its
--   running time.
--
--   * Variables and literals cost nothing.
--   * An application costs the sum of its arguments' costs plus the call
--     itself (the value-level translation of the whole application).
--   * Each recognized binary operator costs a constant 3 plus the cost of both
--     operands.
--   * A @case@ becomes an @HsCase@ node whose rules map each alternative's first
--     pattern to the cost of its body; the scrutinee is translated at value level.
--
--   Anything else raises an error.
lhsExprToTimeExp :: OutputableBndr id => LHsExpr id -> Exp
lhsExprToTimeExp e@(L _ expr) =
    case expr of
      HsVar _ -> EInt 0
      HsOverLit _ -> EInt 0
      HsApp _ _ ->
          let
              -- Arguments of a (possibly nested) application, left to right.
              collectArgs (L _ (HsApp g a)) = collectArgs g ++ [a]
              collectArgs _ = []
              argumentCost = EFunc S.plus $ map lhsExprToTimeExp (collectArgs e)
          in
            EFunc S.plus [argumentCost, lhsExprToExp e]
      OpApp x op _ y ->
          case showPpr op of -- TODO: Find a better way to match on the op
            name
                | name `elem` costedOps ->
                    EFunc S.plus [EInt 3, lhsExprToTimeExp x, lhsExprToTimeExp y]
                | otherwise ->
                    error $ "lhsExprToTimeExp: Unrecognized op: " ++ name
      HsPar inner -> lhsExprToTimeExp inner
      HsCase scrutinee (MatchGroup alts _) ->
          let
              -- Only the first guarded right-hand side is used, and its guards are ignored.
              bodyCost (GRHSs guarded _) =
                  case guarded !! 0 of
                    L _ (GRHS _ body) -> lhsExprToTimeExp body
              -- Only the first pattern of each alternative is used.
              toRule (L _ (Match pats _ grhss)) =
                  EFunc S.rule [patToExp' $ unLoc $ pats !! 0, bodyCost grhss]
          in
            EFunc (ESym "HsCase") [lhsExprToExp scrutinee, EFunc S.list (map toRule alts)]
      other -> error $ "No lhsExprToTimeExp for " ++ (showSDocDebug $ ppr other)
    where
      -- Printed forms of the operators that are all charged the same constant cost.
      costedOps = ["(+)", "(-)", "(*)", "(/)", "(GHC.Types.:)"]


-- | Converts a function-argument pattern into an expression.
--
--   Numeric literal patterns become their value and variable patterns become a
--   symbol.  List patterns are translated as follows: the empty-list constructor
--   becomes 0, and a cons of two variable patterns becomes one plus the symbol
--   of the tail variable.  Unsupported patterns raise an error.
patToExp :: OutputableBndr id => Pat id -> Exp
patToExp pat =
    case pat of
      NPat n _ _ -> overLitToExp n
      VarPat v -> ESym $ showPpr v
      ParPat p -> patToExp $ unLoc p
      ConPatOut con _ _ _ args _
          | conName == "[]" -> EInt 0
          | conName == ":" ->
              let
                  -- Both constructor arguments must be plain variables; only the tail is used.
                  [L _ (VarPat _), L _ (VarPat xs)] = hsConPatArgs args
              in
                EFunc S.plus [EInt 1, ESym $ showPpr xs]
          where
            conName = occNameString $ nameOccName $ dataConName $ unLoc con
      -- Also reached by constructor patterns other than the two handled above.
      _ -> error $ "patToExp not yet supported for " ++ showPpr pat


-- | Converts a pattern to an expression.  This one makes an actual pattern.
--
--   Numeric literal patterns become their value; variable patterns become a
--   named pattern that matches anything (a blank).  Parenthesized patterns are
--   unwrapped, as in 'patToExp'.  Any other pattern raises a descriptive error
--   rather than an anonymous pattern-match failure.
patToExp' :: OutputableBndr id => Pat id -> Exp
patToExp' (NPat n _ _) = overLitToExp n
patToExp' (VarPat v) = EFunc S.pattern [ESym $ showPpr v, EFunc S.blank []]
patToExp' (ParPat p) = patToExp' $ unLoc p
patToExp' x = error $ "patToExp' not yet supported for " ++ showPpr x


-- | Converts a literal to an expression.
--
--   Integral literals become an integer; fractional literals become the
--   numerator times the denominator raised to the power -1.  Any other kind of
--   overloaded literal (e.g. a string literal) raises a descriptive error
--   rather than an anonymous pattern-match failure.
overLitToExp :: OutputableBndr id => HsOverLit id -> Exp
overLitToExp (OverLit (HsIntegral n) _ _ _) = EInt n
overLitToExp (OverLit (HsFractional n) _ _ _) = EFunc S.times [EInt (numerator n), EFunc S.power [EInt $ denominator n, EInt (-1)]]
overLitToExp lit = error $ "overLitToExp: unsupported literal: " ++ showPpr lit


-- | Strips the qualifier @s ++ "."@ from the front of every symbol name in an
--   expression, descending into function heads and arguments.  Symbols without
--   that qualifier and integers are returned unchanged.
trimNS :: String -> Exp -> Exp
trimNS s expr =
    case expr of
      EInt x -> EInt x
      ESym x
          | qualifier `isPrefixOf` x -> ESym $ drop (length qualifier) x
          | otherwise -> ESym x
      EFunc f args -> EFunc (trimNS s f) $ map (trimNS s) args
    where
      qualifier = s ++ "."


-- | Parses a module and produces the recurrence equations describing the
--   execution time of one of its functions, with the module qualifier removed
--   from all symbol names.
genRecurrences :: String -- ^ Module name the function resides in
               -> String -- ^ Name of the function
               -> IO [Exp] -- ^ Recurrence equations
genRecurrences modName f =
    liftM equations (parseHaskell modName)
    where
      equations binds =
          [trimNS modName eq | L _ bind <- binds, eq <- handleBind f bind]


-- | Loads a module by name through the GHC API and returns its typechecked
--   top-level bindings.  Runs a fresh GHC session each time it is called, with
--   the session's temporary directory set to @/tmp@.
parseHaskell :: String -> IO [Located (HsBind Var)]
parseHaskell modName =
    runGhc (Just libdir) $ do
      flags <- getSessionDynFlags
      setSessionDynFlags $ setTmpDir "/tmp" flags
      addTarget target
      -- Dependency analysis is run before the module summary is requested; its result is not used.
      _ <- depanal [] True
      summary <- getModSummary targetModule
      parsed <- parseModule summary
      checked <- typecheckModule parsed
      return $ bagToList $ typecheckedSource checked
    where
      targetModule = mkModuleName modName
      target = Target {
                 targetId = TargetModule targetModule,
                 targetAllowObjCode = False,
                 targetContents = Nothing
               }


-- | Groups the terms of a polynomial by the powers of the given variables.
--
--   Each key is a product of every variable raised to the power it has in a
--   term; each value is the list of coefficients (the terms with the variables
--   removed) that share that key, in the order the terms were given.
--   Only works with polynomials.  Will possibly return incorrect result or fail otherwise.
collectCoefficients :: [Exp] -- ^ Terms of the polynomial
                    -> [Exp] -- ^ Variables
                    -> Map.Map Exp [Exp] -- ^ Map of base terms to lists of coefficients
collectCoefficients terms vars =
    foldr addTerm Map.empty terms
    where
      -- File one term under its variable powers, ahead of the terms after it.
      addTerm term acc =
          Map.insertWith (++) (powersOf term) (removeVars [term]) acc

      -- Product of every variable raised to the power it has in the term.
      powersOf term =
          EFunc S.times [EFunc S.power [v, EInt (varPow v term)] | v <- vars]

      -- Power of a variable within a term.
      varPow :: Exp -> Exp -> Integer
      varPow _ (EInt _) = 0
      varPow v s@(ESym _) = if s == v then 1 else 0
      varPow v (EFunc (ESym "Power") (e:EInt p:[])) = p * varPow v e
      varPow v (EFunc (ESym "Times") args) = sum $ map (varPow v) args
      varPow _ e = error $ "varPow called with: " ++ show e

      -- Drops every factor that is one of the variables (or a power of one).
      removeVars = concatMap strip

      strip e@(EInt _) = [e]
      strip e@(ESym _)
          | e `elem` vars = []
          | otherwise = [e]
      strip (EFunc (ESym "Times") args) = [EFunc S.times (removeVars args)]
      strip e@(EFunc (ESym "Power") (ESym s:EInt _:[]))
          | ESym s `elem` vars = []
          | otherwise = [e]


-- | Splits a linear expression into the integer coefficient of each variable
--   and a constant.
--
--   A sum is read term by term; any other expression is treated as a sum with
--   a single term.  The returned constant is the NEGATED sum of the integer
--   terms, i.e. the right-hand side when the expression is set equal to zero.
--   A term that is not a symbol, an integer, or an integer times one factor
--   raises an error.
linearCoefficients :: Exp -- ^ Linear expression
                   -> [Exp] -- ^ Variables
                   -> ([Integer], Integer) -- ^ Coefficients and a constant term
linearCoefficients expr vars =
    case expr of
      EFunc (ESym "Plus") terms ->
          ( [sum [coefficientOf v t | t <- terms] | v <- vars]
          , negate (sum (map constantOf terms))
          )
      _ -> linearCoefficients (EFunc S.plus [expr]) vars
    where
      -- Integer terms contribute to the constant; everything else contributes 0.
      constantOf term =
          case term of
            EInt x -> x
            _ -> 0
      -- Coefficient of variable v contributed by a single term.
      coefficientOf v term =
          case term of
            ESym _ -> if v == term then 1 else 0
            EFunc (ESym "Times") (EInt c:s:[]) -> if v == s then c else 0
            EInt _ -> 0
            _ -> error $ "Bad term: " ++ show term


-- | Creates every product of non-negative integer powers of the variables
--   whose powers sum to at most the given bound.
--
--   Terms are listed by increasing total power (0 first, the bound last).  Every
--   term is a product with one power factor per variable, including factors
--   with power 0.
createPolyTerms :: Integer -- ^ Largest value the powers may sum to
                -> [Exp] -- ^ Variables
                -> [Exp] -- ^ Products of powers of variables
createPolyTerms power vars =
    [ EFunc S.times (zipWith factor vars powers)
    | total <- [0..power]
    , powers <- compositions (length vars) total
    ]
    where
      -- (compositions n s) returns all lists of n non-negative integers that sum to s
      compositions 0 0 = [[]]
      compositions 0 _ = []
      compositions n s =
          [x : rest | x <- [0..s], rest <- compositions (n - 1) (s - x)]
      factor v p = EFunc S.power [v, EInt $ toInteger p]


-- | Multiplies each term by its corresponding coefficient.
--
--   Every term must be a product (as built by 'createPolyTerms'); the
--   coefficient is prepended to that product's factors.  The result is as long
--   as the shorter of the two lists.
createPoly :: [Exp] -- ^ Terms
           -> [Exp] -- ^ Coefficients
           -> [Exp] -- ^ Products of each term and its corresponding coefficient
createPoly terms coefs =
    zipWith scale coefs terms
    where
      scale coef (EFunc (ESym "Times") factors) = EFunc S.times $ coef : factors


-- | Rewrites an equation so that the arguments of the function on its left
--   side use the canonical argument names registered for that function.
--
--   A plain symbol argument is replaced by its canonical name, an integer
--   argument is left alone, and an argument of the form @n + v@ causes @v@ to
--   be replaced by the canonical name minus @n@.  The same substitution is
--   applied to the rest of the equation.
fixargs :: Map.Map Exp [Exp] -- ^ Maps function names to lists of argument names
        -> Exp -- ^ Equation to fix
        -> Exp -- ^ Equation with function arguments renamed
fixargs newargs (EFunc (ESym "Equal") (EFunc fname oldArgs:rest)) =
    EFunc (ESym "Equal") (EFunc fname (genRep' substitutions oldArgs):genRep' substitutions rest)
    where
      -- Canonical argument names for this function; it must be present in the map.
      Just canonical = Map.lookup fname newargs
      -- Replacement rules, built left to right so later arguments override earlier ones.
      substitutions = foldl record Map.empty $ zip oldArgs canonical
      record m (name@(ESym _), narg) = Map.insert name [narg] m
      record m (EInt _, _) = m
      record m (EFunc (ESym "Plus") [EInt x, y@(ESym _)], narg) =
          Map.insert y [EFunc S.plus [narg, EFunc S.times [EInt (-1), EInt x]]] m


-- | Command-line entry point: expects exactly a module name and a function
--   name, analyzes that function with debugging output enabled, and prints the
--   resulting formula.
main :: IO ()
main =
    do
      args <- getArgs
      case args of
        [modName, funcName] ->
            do
              context <- initialContextWithCore
              formula <- analyze True context modName funcName
              print formula
        _ -> error "Usage: analyze <module> <function>"


-- | Analyzes the execution time of a function.
--
--   Outline of the procedure:
--
--   1. Parse the module once and derive recurrence equations for the function
--      and for every dependency reported by @collectDependencies@.
--   2. Give every function canonical argument names (@arg\<i\>\<function\>@) and
--      rename the equations accordingly.
--   3. Assume each function's running time is a polynomial of degree at most 2
--      in its arguments with unknown coefficients (@c\<i\>\<function\>@), and
--      install that definition in the math context.
--   4. Substitute the polynomials into the equations, normalize them, and group
--      the left-hand sides by powers of the arguments; each group yields one
--      linear equation in the unknown coefficients.
--   5. Solve the linear system and rebuild the polynomial of the requested
--      function, renaming its arguments to @x\<i\>@.
--
--   Many steps rebind a name with @name <- return $ ...@; a @let@ would make
--   the binding recursive, so that form is intentional.
analyze :: Bool -- ^ True if we should print debugging output
        -> Context -- ^ Math context, call @initialContextWithCore@ to get one
        -> String -- ^ Module name
        -> String -- ^ Name of the function to analyze
        -> IO Exp -- ^ Formula for the expected execution time of the function
analyze verbose c modName f =
    do
      let argPrefix = "arg"
      let coefPrefix = "c"
      let polyDegree = 2
      let debug s = when verbose $ putStrLn s
      debug "============="
      debug $ "Analyzing " ++ f
      debug "-------------"
      binds <- parseHaskell modName
      -- Derive the equations from the bindings parsed above instead of
      -- re-parsing the module in a new GHC session for every dependency.
      let recurrencesFor dep = map (trimNS modName) $ concatMap (handleBind dep . unLoc) binds
      let deps = Set.insert f $ collectDependencies modName binds f
      -- Equations are concatenated in ascending order of function name.
      exps <- return $ concatMap recurrencesFor $ Set.toList deps
      mapM_ (debug . show) exps
      debug "-------------"
      -- Number of arguments of each function, read off the left side of its equations.
      let collectNargs e m =
              let (EFunc (ESym "Equal") (EFunc h args:_)) = e
              in Map.insert h (length args) m
      let nargs = foldr collectNargs Map.empty exps
      -- Canonical argument names per function: arg<i><function>.
      let args = Map.mapWithKey (\(ESym f) n ->
                                     map (\x -> ESym $ argPrefix ++ show x ++ f) [0..n - 1]
                                ) nargs
      debug $ "nargs: " ++ show nargs
      debug $ "args: " ++ show args
      debug "-------------"
      --c <- initialContextWithCore
      -- Evaluate the equations as one list in the supplied context.
      EFunc (ESym "List") exps <- return $ fst $ eval' (EFunc S.list exps) c []
      debug "exps:"
      exps <- return $ map (fixargs args) exps
      mapM_ (debug . show) exps
      debug "-------------"
      -- Candidate polynomial terms per function, and one unknown coefficient symbol per term.
      polyTerms <- return $ Map.mapWithKey (\(ESym f) args -> createPolyTerms polyDegree args) args
      polyCoefs <- return $ Map.mapWithKey (\(ESym f) terms ->
                                                map (\x -> ESym $ coefPrefix ++ show x ++ f) [0..(length terms) - 1]
                                           ) polyTerms
      -- Extend the context with a delayed definition of every function as its unknown polynomial.
      c <- return $ Map.foldWithKey (\f args c ->
                                         let
                                             poly = createPoly (polyTerms Map.! f) (polyCoefs Map.! f)
                                             fpat = EFunc f $ map (\x -> EFunc S.pattern [x, EFunc S.blank []]) args
                                             fdef = EFunc S.setDelayed [fpat, EFunc S.plus poly]
                                         in
                                           snd $ eval' fdef c []
                                    ) c args
      debug $ "polyCoefs: " ++ show polyCoefs
      debug $ "polyTerms: " ++ show polyTerms
      -- Re-evaluate the equations in the extended context.
      exps <- return $ fmap (fst . (\x->eval' x c [])) exps
      debug "exps:"
      mapM_ (debug . show) exps
      debug "-------------"
      -- Leftify and ExpandAll are defined by the math context; presumably they
      -- move everything to the left side and expand products (confirm in Ein).
      exps <- return $ fmap (fst . (\x->eval' (EFunc (ESym "Leftify") [x]) c [])) exps
      mapM_ (debug . show) exps
      debug "-------------"
      exps <- return $ fmap (fst . (\x->eval' (EFunc (ESym "ExpandAll") [x]) c [])) exps
      mapM_ (debug . show) exps
      debug "-------------"
      leftSides <- return $ fmap (\(EFunc (ESym "Equal") (lhs:rhs:[])) -> lhs) exps
      -- A left side that is not a sum is treated as a single term.
      terms <- return $ fmap (\x -> case x of (EFunc (ESym "Plus") args) -> args; _ -> [x]) leftSides
      let args' = Map.fold (++) [] args
      maps <-return $ fmap (\x -> collectCoefficients x args') terms
      mapM_ (debug . show) maps
      debug "-------------"
      -- Each group of coefficients sharing the same argument powers becomes one sum.
      sums <- return $ concat $ fmap (
                                     \m ->
                                         map (\x -> EFunc S.plus x) (Map.elems m)
                                     ) maps
      mapM_ (debug . show) sums
      debug "-------------"
      sums <- return $ fmap (fst . (\x->eval' x c [])) sums
      mapM_ (debug . show) sums
      debug "-------------"
      -- Put the unknowns of the requested function first so its part of the
      -- solution can be taken from the front of the solution vector.
      let Just fPolyCoefs' = Map.lookup (ESym f) polyCoefs
      let polyCoefs' = fPolyCoefs' ++  Map.fold (++) [] (Map.delete (ESym f) polyCoefs)
      coefs <- return $ map (\x -> linearCoefficients x polyCoefs') sums
      mapM_ (debug . show) coefs
      debug "-------------"
      matrix <- return $ map fst coefs
      values <- return $ map snd coefs
      mapM_ (debug . show) matrix
      mapM_ (debug . show) values
      -- Width of the first row, or 0 when the matrix has no rows.
      debug $ "Matrix size: " ++ show (length matrix) ++ "x" ++ show (length $ concat $ take 1 matrix)
      debug $ "Vector size: " ++ show (length values)
      debug "-------------"
      solution <- return $ solveLinearSystem (map (map (%1)) matrix) (map (%1) values)
      let fSolution = take (length fPolyCoefs') solution
      debug $ "solution: " ++ show solution
      debug $ "fSolution: " ++ show fSolution
      debug "-------------"
      let Just fPolyTerms = Map.lookup (ESym f) polyTerms
      debug $  "fPolyTerms: " ++ show fPolyTerms
      -- Rebuild the polynomial of the requested function from its solved coefficients.
      solutionExp <- return $ EFunc S.plus (zipWith (\s x -> EFunc S.times [ratioToExp s, x]) fSolution fPolyTerms)
      debug $ show solutionExp
      debug "-------------"
      solutionExp' <- return $ trimArgs argPrefix "x" f $ fst $ eval' solutionExp c []
      debug $ "Final solution for " ++ f ++ ": " ++ show solutionExp'
      return solutionExp'


-- | Renames argument symbols throughout an expression: a symbol that starts
--   with the old prefix has that prefix replaced by the new one and loses as
--   many trailing characters as the suffix is long.  The suffix itself is not
--   compared against the symbol.  Other symbols and integers are unchanged.
trimArgs :: String -- ^ Prefix to remove
         -> String -- ^ Prefix to add
         -> String -- ^ Suffix to remove
         -> Exp -- ^ Expression to fix
         -> Exp -- ^ Expression with arguments renamed
trimArgs prefix prefix' f expr =
    case expr of
      EInt x -> EInt x
      ESym x
          | prefix `isPrefixOf` x -> ESym $ prefix' ++ dropSuffix (drop (length prefix) x)
          | otherwise -> ESym x
      EFunc h args -> EFunc (recurse h) $ map recurse args
    where
      recurse = trimArgs prefix prefix' f
      -- Removes as many trailing characters as the suffix has.
      dropSuffix name = take (length name - length f) name


-- | Represents a ratio as the product of its numerator and its denominator
--   raised to the power -1.
ratioToExp :: Integral a => Ratio a -> Exp
ratioToExp x =
    EFunc S.times [EInt top, EFunc S.power [EInt bottom, EInt (-1)]]
    where
      top = toInteger $ numerator x
      bottom = toInteger $ denominator x
