--- ----------------------------------------------------------------------------
--- This module provides a transformation for typed FlatCurry which changes
--- the type of locally polymorphic sub-expressions to `()` (unit type).
---
--- A locally polymorphic sub-expression is an expression whose type
--- contains one or more type variables which are not included in the type
--- of the enclosing function declaration.
---
--- For instance, the function
---
---    f :: Bool
---    f = null []
---
--- where `[]` has the type `[a]`
---
--- will be transformed into
---
---    f :: Bool
---    f = null ([] :: [()])
---
--- This allows functions of the above-mentioned kind to be compiled into
--- Haskell without type errors.
---
--- @author Björn Peemöller
--- @version July 2013
--- ----------------------------------------------------------------------------

module KiCS2.DefaultPolymorphic (defaultPolymorphic) where

import Data.List ((\\), nub)

import Control.Monad.Trans.State
import Data.Map
import FlatCurry.Annotated.Types
import FlatCurry.Annotated.Goodies
import KiCS2.FlatCurry.Annotated.TypeSubst

--- The Default Polymorphic Monad: a `State` computation whose state is the
--- type substitution (`AFCSubst`) accumulated while a function is processed.
--- `dpFunc` seeds the state with the type variables of the function's
--- signature (each mapped to itself), and `dflt` extends it with every
--- local type variable that gets defaulted.
type DPM a = State AFCSubst a

--- The type that every local type variable is replaced with:
--- the unit type `()` from the `Prelude`.
defaultType :: TypeExpr
defaultType = TCons unitName []
  where
  -- qualified name of the unit type constructor
  unitName = ("Prelude", "()")

--- Transform a typed FlatCurry program: every function declaration of the
--- program is processed by `dpFunc`, which replaces all local type
--- variables with the `defaultType`.
defaultPolymorphic :: AProg TypeExpr -> AProg TypeExpr
defaultPolymorphic prog = updProgFuncs (map dpFunc) prog

--- Transform a single function declaration.
--- The rule of the function is transformed first, starting from a
--- substitution that maps every type variable of the function's own type
--- to itself (so that these variables are never treated as local ones).
--- The resulting substitution is then applied to the whole function so
--- that all annotated types stay consistent.
--- Functions whose unqualified name starts with `_impl#` or `_inst#`
--- (i.e., those generated by the dictionary transformation in the front
--- end) are returned unchanged.
dpFunc :: AFuncDecl TypeExpr -> AFuncDecl TypeExpr
dpFunc afunc@(AFunc name k v sigTy rule)
  | isDictFunc = afunc
  | otherwise  = substFunc sigma (AFunc name k v sigTy rule')
  where
  -- is the function generated by the dictionary transformation?
  isDictFunc     = take 6 (snd name) `elem` ["_impl#", "_inst#"]
  -- type variables occurring in the function's type
  sigVars        = tyVars sigTy
  -- identity substitution on the type variables of the signature
  initSubst      = fromList (zip sigVars (map TVar sigVars))
  -- transformed rule and the substitution collected on the way
  (rule', sigma) = runState (dpRule rule) initSubst

--- Transform a single rule.
--- For a defined rule only the right-hand side expression is transformed;
--- an external rule is returned as it is.
dpRule :: ARule TypeExpr -> DPM (ARule TypeExpr)
dpRule rule = case rule of
  ARule ty vs body -> fmap (ARule ty vs) (dpExpr body)
  AExternal _ _    -> return rule

--- Transform a single expression.
--- The expression is traversed bottom-up: all sub-expressions are
--- transformed first (from left to right), then the rebuilt expression is
--- passed to `dflt` together with its annotated type. Consequently, the
--- smallest expression `e` whose type `ty` mentions a new local type
--- variable is the one that gets wrapped into an explicitly typed
--- expression `e :: ty'`, where `ty'` is `ty` with the local type variables
--- replaced by `()` (the default type).
--- Since the substitution is threaded through the whole traversal, a local
--- type variable is defaulted only at its first occurrence. That is,
---
---     null ([] ++ [])
---
--- gets transformed to
---
---     null (([] :: [()]) ++ [])
---
--- instead of
---
---     null (([] :: [()]) ++ ([] :: [()]))
---
--- because both `[]` expressions share the same local type variable.
--- Note that case expressions and their branches are only rebuilt;
--- `dflt` is not called for them.
dpExpr :: AExpr TypeExpr -> DPM (AExpr TypeExpr)
dpExpr = trExpr onVar onLit onComb onLet onFree onOr onCase onBranch onTyped
  where
  onVar ty v = dflt ty (AVar ty v)

  onLit ty l = dflt ty (ALit ty l)

  onComb ty ct qn margs = do
    args <- sequence margs
    dflt ty (AComb ty ct qn args)

  -- the bound expressions are processed before the body of the let
  onLet ty binds mbody = do
    rhss <- sequence (map snd binds)
    body <- mbody
    dflt ty (ALet ty (zip (map fst binds) rhss) body)

  onFree ty vs mbody = do
    body <- mbody
    dflt ty (AFree ty vs body)

  onOr ty mleft mright = do
    left  <- mleft
    right <- mright
    dflt ty (AOr ty left right)

  -- the scrutinee is processed before the branches; no defaulting here
  onCase ty ct mscrut mbranches = do
    scrut    <- mscrut
    branches <- sequence mbranches
    return (ACase ty ct scrut branches)

  onBranch pat mbody = fmap (ABranch pat) mbody

  onTyped ty minner ty2 = do
    inner <- minner
    dflt ty (ATyped ty inner ty2)

--- Check whether the given `TypeExpr` contains type variables that are not
--- yet part of the current substitution.
--- If there are none, the expression is returned unchanged. Otherwise the
--- substitution is extended by mapping each of these variables to the
--- `defaultType`, and the expression is wrapped into an explicitly typed
--- expression carrying the given (not yet substituted) type.
dflt :: TypeExpr -> AExpr TypeExpr -> DPM (AExpr TypeExpr)
dflt ty e = do
  sub <- get
  -- type variables of `ty` which the substitution does not know yet
  let unseen = [v | v <- tyVars ty, not (member v sub)]
  case unseen of
    [] -> return e
    _  -> do
      modify (union (fromList (zip unseen (repeat defaultType))))
      return (ATyped ty e ty)

--- Retrieve all type variables in a type expression.
--- Each variable is listed only once, in the order of its first occurrence.
tyVars :: TypeExpr -> [TVarIndex]
tyVars ty = nub (trTypeExpr single ofArgs ofBoth ofBody ty)
  where
  -- a single index yields a one-element list
  single tv        = [tv]
  -- the first argument is ignored, the collected lists are flattened
  ofArgs _ tvss    = concat tvss
  -- variables of both parts, left part first
  ofBoth tvs1 tvs2 = tvs1 ++ tvs2
  -- the first argument is ignored, the second result is passed on
  ofBody _ tvs     = tvs
