Raw File
CGen.hs
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE BangPatterns #-}
{-# OPTIONS_GHC "-fmax-pmcheck-iterations=10000000" #-}

--------------------------------------------------------------------------------
-- | This module generates refinement type constraints
--   (see Cosman and Jhala, ICFP '17).
--------------------------------------------------------------------------------

module Language.Mist.CGen
  ( generateConstraints
  , NNF (..)
  ) where

import Language.Mist.Types
import Language.Mist.Names
import Data.Bifunctor (second)

-------------------------------------------------------------------------------
-- Data Structures
-------------------------------------------------------------------------------
-- | Typing environment: an association list from identifiers to their
-- refinement types. Bindings are consed onto the front as scopes are
-- entered and variables are resolved with 'lookup', so the binding closest
-- to the head of the list is the one found for a given name.
type Env r a = [(Id, RType r a)]

-- | Constraint synonym (enabled by @ConstraintKinds@) bundling the class
-- constraints that the functions in this module place on the refinement
-- representation @r@ and on the annotation type @a@ (threaded through this
-- module as @loc@ / @l@).
type CGenConstraints r a = (Predicate r, Show r, Show a, PPrint r, Eq r)

-------------------------------------------------------------------------------
-- | Main entry point of this module. Runs 'strengthening' on the elaborated
-- expression in the empty environment inside 'runFresh' and keeps only the
-- generated constraint; the refinement type computed alongside it is dropped.
-------------------------------------------------------------------------------
generateConstraints :: CGenConstraints r a => ElaboratedExpr r a -> NNF r
generateConstraints expr = fst (runFresh (strengthening [] TUnit expr))

-- | Recognises the expression forms that '_strengthening' delegates to
-- 'synth': applications, type applications, identifiers, primitives and the
-- number / boolean / unit literals. Every other form yields 'False'.
isApplicationForm :: ElaboratedExpr r a -> Bool
isApplicationForm = \case
  App{}     -> True
  TApp{}    -> True
  Id{}      -> True
  Prim{}    -> True
  Number{}  -> True
  Boolean{} -> True
  Unit{}    -> True
  _         -> False

-- | Thin wrapper around '_strengthening'. It forwards its arguments and
-- returns the result unchanged; it exists so that a trace line can be enabled
-- in one place while debugging.
strengthening :: CGenConstraints r a =>
        Env r a -> Type -> ElaboratedExpr r a -> Fresh (NNF r, RType r a)
strengthening env expected expr = do
  (constraint, strengthened) <- _strengthening env expected expr
  -- !_ <- traceM $ "strengthening: " <> "⊢ " <> pprint expr <> " <= " <> pprint expected <> " ↑ " <> pprint strengthened
  pure (constraint, strengthened)

-- | Computes a constraint and a refinement type for an expression, guided by
-- its unrefined type.
--
--   * Application forms (see 'isApplicationForm') are handed to 'synth'.
--   * A lambda whose binder carries an assumed type yields that type with the
--     trivial constraint @CAnd []@; its body is not inspected.
--   * A lambda at a function type with an unrefined binder gets a 'fresh'
--     template for the argument; one with a refined binder uses the
--     annotation as-is. Both then go through 'strengthenLamBody'.
--   * Any other lambda is an internal error.
--   * A type abstraction at a @TForall@ type strengthens its body and wraps
--     the resulting type in @RForallP@.
--   * Everything else gets a 'fresh' template which the expression is then
--     'check'ed against.
_strengthening :: CGenConstraints r a =>
        Env r a -> Type -> ElaboratedExpr r a -> Fresh (NNF r, RType r a)
_strengthening env _ e
  | isApplicationForm e = synth env e
_strengthening _ _ (Lam (AnnBind _ (Just (ElabAssume assumed)) _) _ _) =
  pure (CAnd [], assumed)
_strengthening env (_ :=> resTy) (Lam (AnnBind x (Just (ElabUnrefined argTy)) loc) body _) =
  fresh loc env argTy >>= strengthenLamBody env x loc resTy body
_strengthening env (_ :=> resTy) (Lam (AnnBind x (Just (ElabRefined argRT)) loc) body _) =
  strengthenLamBody env x loc resTy body argRT
_strengthening _ _ (Lam _ _ _) =
  error "should not occur"
_strengthening env (TForall _ bodyTy) (TAbs tvar body _) =
  second (RForallP tvar) <$> strengthening env bodyTy body
_strengthening env typ e = do
  template <- fresh (extractLoc e) env typ
  (, template) <$> check env e template

-- | Strengthens a lambda body with the parameter @x@ bound at @argRT@, then
-- quantifies the body constraint over @x@ with 'mkAll' and builds the
-- resulting @RFun@ type.
strengthenLamBody :: CGenConstraints r a =>
        Env r a -> Id -> a -> Type -> ElaboratedExpr r a -> RType r a
        -> Fresh (NNF r, RType r a)
strengthenLamBody env x loc resTy body argRT = do
  (bodyC, bodyRT) <- strengthening ((x, argRT):env) resTy body
  pure (mkAll x argRT bodyC, RFun (Bind x loc) argRT bodyRT)

-- | Thin wrapper around '_synth'. It forwards its arguments and returns the
-- result unchanged; it exists so that a trace line can be enabled in one
-- place while debugging.
synth :: CGenConstraints r a =>
        Env r a -> ElaboratedExpr r a -> Fresh (NNF r, RType r a)
synth env expr = do
  (constraint, synthesized) <- _synth env expr
  -- !_ <- traceM $ "synth: " <> "⊢ " <> pprint expr <> " => " <> pprint synthesized
  pure (constraint, synthesized)

-- | Synthesises a constraint and refinement type for an application form.
--
--   * Literals and primitives take their type from @prim@, and identifiers
--     from 'single'; both come with the constraint @Head true@.
--   * @App e y@ synthesises the type of @e@ and applies it to the variable
--     @y@ via 'appSynth'. A non-variable argument is an error.
--   * @TApp e typ@ synthesises a scheme for @e@ and instantiates its bound
--     variable: with a 'stale' template for @RForall@, and with a 'fresh'
--     template for @RForallP@.
--
-- Calling this on any other expression form is an internal error.
_synth :: CGenConstraints r a =>
        Env r a -> ElaboratedExpr r a -> Fresh (NNF r, RType r a)
_synth env = \case
    e@Unit{}    -> trivially (prim e)
    e@Number{}  -> trivially (prim e)
    e@Boolean{} -> trivially (prim e)
    e@Prim{}    -> trivially (prim e)
    Id x _      -> trivially (single env x)

    App fun (Id y loc) _ -> do
      (cFun, tFun) <- synth env fun
      (cApp, tRes) <- appSynth env tFun y loc
      pure (CAnd [cFun, cApp], tRes)

    App{} -> error "argument is non-variable"

    TApp e typ loc -> do
      (c, scheme) <- synth env e
      -- Pick the instantiation according to the kind of quantifier.
      (alpha, body, instantiation) <- case scheme of
        RForall (TV alpha) body  -> (alpha, body,) <$> stale loc typ
        RForallP (TV alpha) body -> (alpha, body,) <$> fresh loc env typ
        _ -> error "TApp on not-a-scheme"
      pure (c, substReftReft (alpha |-> instantiation) body)

    _ -> error "Internal Error: Synth called on non-application form"
  where
    -- Pairs a synthesised type with the trivial constraint.
    trivially action = (Head true,) <$> action

-- | Builds a refinement type of the given shape in which every refinement is
-- @true@ (contrast with 'fresh', which inserts kvars). Binder names come from
-- 'refreshId': @"staleArg"@-prefixed for type variables and function
-- arguments, @"VV"@-prefixed for value binders.
stale :: CGenConstraints r a => a -> Type -> Fresh (RType r a)
stale loc = \case
    t@TVar{} -> do
      x <- refreshId $ "staleArg" ++ cSEPARATOR
      pure $ RBase (Bind x loc) t true
    TUnit -> staleBaseType loc TUnit
    TInt  -> staleBaseType loc TInt
    TBool -> staleBaseType loc TBool
    TSet  -> staleBaseType loc TSet
    dom :=> cod -> do
      domRT <- stale loc dom
      x <- refreshId $ "staleArg" ++ cSEPARATOR
      codRT <- stale loc cod
      pure $ RFun (Bind x loc) domRT codRT
    TCtor ctor types -> do -- TODO: reduce duplication with staleBaseType
      v <- refreshId $ "VV" ++ cSEPARATOR
      -- Each argument keeps its first component; only the type is rebuilt.
      args <- traverse (traverse (stale loc)) types
      pure $ RRTy (Bind v loc) (RApp ctor args) true
    TForall tvar typ -> RForall tvar <$> stale loc typ

-- | Wraps a base type in an @RBase@ with a freshly named @"VV"@ binder and
-- the refinement @true@.
staleBaseType :: (Predicate r) => a -> Type -> Fresh (RType r a)
staleBaseType l baseType =
  (\v -> RBase (Bind v l) baseType true) <$> refreshId ("VV" ++ cSEPARATOR)

-- | Thin wrapper around '_appSynth'. It forwards its arguments and returns
-- the result unchanged; it exists so that a trace line can be enabled in one
-- place while debugging.
appSynth :: CGenConstraints r a => Env r a -> RType r a -> Id -> a -> Fresh (NNF r, RType r a)
appSynth env tfun y loc = do
  (constraint, resultT) <- _appSynth env tfun y loc
  -- !_ <- traceM $ "appSynth: " <> pprint tfun <> " ⊢ " <> y <> " >> " <> pprint resultT
  pure (constraint, resultT)

-- | env | tfun ⊢ y >>
--
-- Synthesises the result type of applying something of type @tfun@ to the
-- variable @y@.
--
--   * For @RFun@, @y@ is checked against the parameter type and the result
--     type has the parameter replaced by @y@.
--   * For @RIFun@, the implicit parameter is renamed to a fresh @z@ and bound
--     in the environment, the application is retried on the remaining type,
--     and the outcome is related by subtyping to a 'fresh' template built in
--     the original environment. The constraint is wrapped with 'mkExists'
--     over @z@.
--   * Any other type is an error.
_appSynth :: CGenConstraints r a => Env r a -> RType r a -> Id -> a -> Fresh (NNF r, RType r a)
_appSynth env tfun y loc = case tfun of
  RFun x tx t ->
    (, substReftPred (bindId x |-> y) t) <$> check env (Id y loc) tx

  RIFun (Bind x _) tx t -> do
    z <- refreshId x
    (cApp, tApp) <- appSynth ((z, tx):env) (substReftPred (x |-> z) t) y loc
    template <- fresh loc env (eraseRType tApp)
    cSub <- tApp <: template
    pure (mkExists z tx (CAnd [cApp, cSub]), template)

  _ -> error "Applying non function"

-- | Looks up @x@ in the environment and returns its (flattened) refinement
-- type.
--
-- For @RBase@ and @RRTy@ types, `x` is already bound, so instead of
-- "re-binding" x we selfify (à la Ou et al. IFIP TCS '04): the value binder is
-- renamed to a fresh @"VV"@ name @v@ and the refinement is strengthened with
-- @varsEqual v x@. Any other type is returned as-is after flattening.
--
-- Calls 'error' when @x@ is not bound in @env@.
single :: CGenConstraints r a => Env r a -> Id -> Fresh (RType r a)
single env x = case flattenRType <$> lookup x env of
  Just (RBase (Bind y l) baseType reft) ->
    selfified y reft (\v -> RBase (Bind v l) baseType)
  Just (RRTy (Bind y l) rt reft) ->
    selfified y reft (\v -> RRTy (Bind v l) rt)
  Just rt -> pure rt
  -- Separate the identifier from the environment dump so the message is readable.
  Nothing -> error $ "Unbound Variable " ++ show x ++ " in environment " ++ show env
  where
    -- Renames the value binder @y@ to a fresh @v@ inside the refinement, adds
    -- the equality with @x@, and rebuilds the type via @rebuild@.
    selfified y reft rebuild = do
      v <- refreshId $ "VV" ++ cSEPARATOR
      pure $ rebuild v (strengthen (subst (y |-> v) reft) (varsEqual v x))


-- | Thin wrapper around '_check'. It forwards its arguments unchanged; it
-- exists so that a trace line can be enabled in one place while debugging.
check :: CGenConstraints r a => Env r a -> ElaboratedExpr r a -> RType r a -> Fresh (NNF r)
check env expr expected = do
  -- !_ <- traceM $ "check: " <> "⊢ " <> pprint expr <> " <= " <> pprint expected
  _check env expr expected

-- | Checks an expression against a refinement type in the given environment
-- and returns the constraint generated by that check.
--
-- The equations are tried top to bottom: @Let@ (dispatching on the binder's
-- annotation), @If@, @Lam@, @TAbs@, @App@, and finally a fallback that
-- strengthens the expression and emits a subtyping constraint.
_check :: CGenConstraints r a => Env r a -> ElaboratedExpr r a -> RType r a -> Fresh (NNF r)
_check env (Let b e1 e2 _) t2
  -- Annotated with an assume
  -- The bound expression e1 is not checked; only the body is, with x bound at
  -- the assumed type.
  | (AnnBind x (Just (ElabAssume tx)) _) <- b = do
    c <- check ((x, tx):env) e2 t2
    pure $ mkAll x tx c

  -- Annotated with an RType (Implicit Parameter)
  -- splitImplicits peels the leading implicit parameters (ns) off rt. e1 is
  -- checked against the remaining type tx with both ns and x in scope, and its
  -- constraint is quantified over ns. The body sees x at the full type rt.
  | (AnnBind x (Just (ElabRefined rt@RIFun{})) _) <- b = do
    let (ns, tx) = splitImplicits rt
    c1 <- check (ns ++ ((x, rt):env)) e1 tx
    let c1' = foldr (\(z, tz) c -> mkAll z tz c) c1 ns
    c2 <- check ((x, rt):env) e2 t2
    pure $ mkAll x rt (CAnd [c1', c2])

  -- Annotated with an RType
  -- x is in scope while checking e1, so the binding may refer to itself.
  | (AnnBind x (Just (ElabRefined tx)) _) <- b = do
    c1 <- check ((x, tx):env) e1 tx
    c2 <- check ((x, tx):env) e2 t2
    pure $ mkAll x tx (CAnd [c1, c2])

  -- Unrefined
  -- Not allowed to be recursive
  -- (e1 is strengthened in env, without x); the type t1 obtained for e1 is
  -- what x is bound to while checking the body.
  | (AnnBind x (Just (ElabUnrefined typ)) _loc) <- b = do
    (c1, t1) <- strengthening env typ e1
    c2 <- check ((x, t1):env) e2 t2
    let c' = mkAll x t1 c2
    pure $ CAnd [c1, c']

  | (AnnBind _ Nothing _) <- b = error "INTERNAL ERROR: annotation missing on let"

-- TODO: check that this rule is correct. In particular the interaction of rebindInEnv generated constraints and mkAll. I think it should be fine due to how single works.
-- Both branches are checked against the same type t2 in the unchanged
-- environment. Each branch constraint is placed under an All over a freshly
-- named TBool binder whose predicate is the condition variable (then-branch)
-- or its negation (else-branch). The condition must be a variable.
_check env (If (Id y _) e1 e2 _) t2 = do
  idT <- refreshId ("then:"<>y)
  idF <- refreshId ("else:"<>y)
  c1 <- check env e1 t2
  c2 <- check env e2 t2
  pure $ CAnd [All idT TBool (var y) c1,
               All idF TBool (varNot y) c2]
_check _ If{} _ = error "INTERNAL ERROR: if not in ANF"

-- A lambda whose binder carries an assumed type generates no obligations.
_check _ (Lam (AnnBind _ (Just (ElabAssume _)) _) _ _) _ = pure (CAnd [])
-- A lambda against a function type: bind the parameter x at the expected
-- argument type ty and check the body against the result type with the
-- function type's binder renamed to x.
_check env (Lam (AnnBind x _ _) e _) (RFun y ty t) =
  mkAll x ty <$> check ((x, ty):env) e (substReftPred (bindId y |-> x) t)

-- Type abstraction against RForallP: the bound variable is replaced by a
-- 'stale' (all-true) refinement type over the abstraction's own type variable.
_check env (TAbs tvar' e loc) (RForallP (TV tvar) t)  = do
  stalert <- stale loc (TVar tvar')
  check env e (substReftReft (tvar |-> stalert) t)
-- Type abstraction against RForall: the bound variable is replaced by the
-- abstraction's own type variable.
_check env (TAbs tvar' e _) (RForall (TV tvar) t) =
  check env e (substReftType (tvar |-> TVar tvar') t)

-- Application to a variable: synthesise the function's type, then check the
-- application against t with 'appCheck'.
_check env (App e (Id y loc) _) t = do
  (c, t') <- synth env e
  c' <- appCheck env t' y loc t
  pure $ CAnd [c, c']

-- Fallback: strengthen the expression at the erased expected type and require
-- the resulting type to be a subtype of t.
_check env e t = do
  (c, t') <- strengthening env (eraseRType t) e
  cSub <- t' <: t
  pure $ CAnd [c, cSub]

-- | Thin wrapper around '_appCheck'. It forwards its arguments unchanged; it
-- exists so that a trace line can be enabled in one place while debugging.
appCheck :: CGenConstraints r a => Env r a -> RType r a -> Id -> a -> RType r a -> Fresh (NNF r)
appCheck env tfun y loc expected = do
  -- !_ <- traceM $ "appCheck: " <> pprint tfun <> " ⊢ " <> y <> " << " <> pprint expected
  _appCheck env tfun y loc expected

-- | env | t ⊢ y << t'
--
-- Checks that applying something of type @t@ to the variable @y@ produces a
-- value of type @t'@.
--
--   * For @RFun@, @y@ is checked against the parameter type and the result
--     type (with the parameter replaced by @y@) must be a subtype of @t'@.
--   * For @RIFun@, the implicit parameter is renamed to a fresh @z@ and bound
--     in the environment, the check is retried on the remaining type, and the
--     constraint is wrapped with 'mkExists' over @z@.
--   * Any other type is an error.
_appCheck :: CGenConstraints r a => Env r a -> RType r a -> Id -> a -> RType r a -> Fresh (NNF r)
_appCheck env tfun y loc expected = case tfun of
  RFun (Bind x _) tx t -> do
    cArg <- check env (Id y loc) tx
    cRes <- substReftPred (x |-> y) t <: expected
    pure $ CAnd [cArg, cRes]

  RIFun (Bind x _) tx t -> do
    z <- refreshId x
    mkExists z tx <$> appCheck ((z, tx):env) (substReftPred (x |-> z) t) y loc expected

  _ -> error "application at non function type"

-- | Builds a refinement-type template of the given shape. Base types and
-- constructor applications get a kvar refinement (built with @buildKvar@ over
-- the value binder and the first-order binders of the environment); type
-- variables get the refinement @true@. For a function type, the result
-- template is built in an environment extended with the argument binder.
fresh :: CGenConstraints r a => a -> Env r a -> Type -> Fresh (RType r a)
fresh loc env = \case
    t@TVar{} -> do
      x <- refreshId $ "karg" ++ cSEPARATOR
      pure $ RBase (Bind x loc) t true
    TUnit -> freshBaseType loc env TUnit
    TInt  -> freshBaseType loc env TInt
    TBool -> freshBaseType loc env TBool
    TSet  -> freshBaseType loc env TSet
    dom :=> cod -> do
      domRT <- fresh loc env dom
      x <- refreshId $ "karg" ++ cSEPARATOR
      codRT <- fresh loc ((x, domRT):env) cod
      pure $ RFun (Bind x loc) domRT codRT
    TCtor ctor types -> do -- TODO: reduce duplication with freshBaseType
      kappa <- refreshId $ "kvar" ++ cSEPARATOR
      v <- refreshId $ "VV" ++ cSEPARATOR
      -- Each argument keeps its first component; only the type is rebuilt.
      args <- traverse (traverse (fresh loc env)) types
      let kvar = buildKvar kappa $ v : map fst (foTypes (eraseRTypes env))
      pure $ RRTy (Bind v loc) (RApp ctor args) kvar
    TForall tvar typ -> RForall tvar <$> fresh loc env typ

-- | Wraps a base type in an @RBase@ whose refinement is a kvar. The kvar is
-- built with @buildKvar@ from a fresh @"kvar"@ name, applied to the fresh
-- @"VV"@ value binder followed by the first-order binders of the environment.
freshBaseType :: (Predicate r) => a -> Env r a -> Type -> Fresh (RType r a)
freshBaseType l env baseType = do
  kappa <- refreshId $ "kvar" ++ cSEPARATOR
  v <- refreshId $ "VV" ++ cSEPARATOR
  let scope = map fst (foTypes (eraseRTypes env))
  pure $ RBase (Bind v l) baseType (buildKvar kappa (v : scope))

-- | Keeps only the environment entries whose type is first-order: type
-- variables, unit, int, bool and set. Function types, constructor
-- applications and foralls are dropped. Order is preserved.
--
-- An earlier variant (left commented out in the original) dropped @TSet@
-- entries as well.
-- TODO(Matt): check that this is the correct behavior
foTypes :: [(Id, Type)] -> [(Id, Type)]
foTypes = filter (isFirstOrder . snd)
  where
    isFirstOrder TVar{}    = True
    isFirstOrder TUnit{}   = True
    isFirstOrder TInt{}    = True
    isFirstOrder TBool{}   = True
    isFirstOrder TSet      = True
    isFirstOrder (_ :=> _) = False
    isFirstOrder TCtor{}   = False
    isFirstOrder TForall{} = False

-- | Erases the refinement type of every binding in the environment with
-- 'eraseRType', leaving names and order unchanged.
eraseRTypes :: Env r a -> [(Id, Type)]
-- Uses 'second' rather than a lambda whose binder shadowed Prelude's 'id'.
eraseRTypes = map (second eraseRType)

-- | Subtyping: produces the constraint under which the left type is a subtype
-- of the right one. A thin wrapper around '<<:' that exists so a trace line
-- can be enabled in one place while debugging.
(<:) :: CGenConstraints r a => RType r a -> RType r a -> Fresh (NNF r)
sub <: sup = do
  constraint <- sub <<: sup
  -- !_ <- traceM $ "subtyping: " <> "⊢ " <> pprint sub <> "\n\t<: " <> pprint sup <> "\n\t⊣  " <> show constraint <> "\n\n"
  pure constraint

-- | Worker for '<:'. Both types are first normalised with 'flattenRType' and
-- then compared structurally by @go@. The clauses of @go@ are tried in order,
-- so the more specific pairs must stay above the catch-all ones. Shape
-- mismatches are reported with 'error'.
(<<:) :: CGenConstraints r a => RType r a -> RType r a -> Fresh (NNF r)
rtype1 <<: rtype2 = go (flattenRType rtype1) (flattenRType rtype2)
  where
    -- Base types: the base types must be equal; the right refinement, with
    -- its binder renamed to the left binder, becomes the head under the left
    -- refinement.
    go (RBase (Bind x1 _) b1 p1) (RBase (Bind x2 _) b2 p2)
      | b1 == b2 = pure $ All x1 b1 p1 (Head $ varSubst (x2 |-> x1) p2)
      | otherwise = error $ "error?" ++ show b1 ++ show b2
    -- Functions: arguments are compared in the reverse direction (t2 <: t1);
    -- results are compared with the left binder renamed to the right one,
    -- under a quantifier for the right argument.
    go (RFun (Bind x1 _) t1 t1') (RFun (Bind x2 _) t2 t2') = do
      c <- t2 <: t1
      c' <- substReftPred (x1 |-> x2) t1' <: t2'
      pure $ CAnd [c, mkAll x2 t2 c']
    -- Quantified types: both sides must bind the same variable.
    go (RForall alpha t1) (RForall beta t2)
      | alpha == beta = t1 <: t2
      | otherwise = error "Constraint generation subtyping error"
    go (RForallP alpha t1) (RForallP beta t2)
      | alpha == beta = t1 <: t2
      | otherwise = error "Constraint generation subtyping error"
    -- Constructor applications: same constructor, arguments related pairwise
    -- by 'constructorSub'. zipWith stops at the shorter argument list.
    go (RApp c1 vts1) (RApp c2 vts2)
      | c1 == c2  = CAnd <$> sequence (concat $ zipWith constructorSub vts1 vts2)
      | otherwise = error "CGen: constructors don't match"
    -- Implicit function on the left: rename its parameter to a fresh z,
    -- compare the remaining type, and wrap the constraint with 'mkExists'.
    go (RIFun (Bind x _) t1 t1') t2 = do
      z <- refreshId x
      cSub <- substReftPred (x |-> z) t1' <: t2
      pure $ mkExists z t1 cSub
    -- Refined types on both sides: relate the outer refinements and,
    -- separately, the underlying types.
    go (RRTy (Bind x1 _) rt1 p1) (RRTy (Bind x2 _) rt2 p2) = do
      let outer = All x1 (eraseRType rt1) p1 (Head $ varSubst (x2 |-> x1) p2)
      inner <- rt1 <: rt2
      pure $ CAnd [outer, inner]
    -- Refined type only on the right: the left binder and refinement come
    -- from 'rtsymreft' (which falls back to "_" and true).
    go rt1 (RRTy (Bind x _) rt2 reft) = do
      let (y,r) = rtsymreft rt1
      let subreft = All y (eraseRType rt1) r (Head $ varSubst (x |-> y) reft)
      inner <- rt1 <: rt2
      pure $ CAnd [inner, subreft]
    -- Refined type only on the left: its refinement becomes a hypothesis for
    -- comparing the underlying type with the right side.
    go (RRTy (Bind x _) rt1 reft) rt2 = All x (eraseRType rt1) reft <$> (rt1 <: rt2)
    -- Anything else is a shape mismatch; the message shows the original
    -- (unflattened) types and their erasures.
    go _ _ = error $ "CGen subtyping error. Got:\n\n" ++ pprint rtype1 ++ "\n\nbut expected:\n\n" ++ pprint rtype2 ++ "\n" ++ "i.e. Got:\n\n" ++ pprint (eraseRType rtype1) ++ "\n\nbut expected:\n\n" ++ pprint (eraseRType rtype2) ++ "\n"

-- | Extracts the value binder name and refinement of an @RBase@ or @RRTy@.
-- Any other type yields the placeholder name @"_"@ with the refinement @true@.
rtsymreft rtype = case rtype of
  RBase (Bind x _) _ r -> (x, r)
  RRTy (Bind x _) _ r  -> (x, r)
  _                    -> ("_", true)

-- | Subtyping obligations for one pair of constructor arguments, selected by
-- the variance attached to the left argument. Returns the list of (not yet
-- run) subtyping actions.
--
-- NOTE(review): @Invariant@ produces no obligations and @Bivariant@ produces
-- both directions. Under the usual meaning of those terms it would be the
-- other way round — confirm against how the variance annotation is computed
-- before changing anything here.
constructorSub (variance, lhs) (_, rhs) = case variance of
    -- TODO: write tests that cover these two cases...
    Invariant     -> []
    Bivariant     -> [forwards, backwards]
    Covariant     -> [forwards]
    Contravariant -> [backwards]
  where
    forwards  = lhs <: rhs
    backwards = rhs <: lhs

-- | Collapses a refinement wrapper (@RRTy@) into the type it wraps when that
-- type has a refinement of its own. The outer refinement is renamed to the
-- inner binder and combined with the inner one via @strengthen@. Wrapping an
-- @RBase@ yields an @RBase@; wrapping another @RRTy@ merges the two and
-- repeats. Anything else is returned unchanged.
flattenRType :: CGenConstraints r a => RType r a -> RType r a
flattenRType outer = case outer of
  RRTy (Bind x _) wrapped p ->
    let -- Inner refinement strengthened with the outer one, phrased over the
        -- inner binder.
        merged by q = strengthen q (varSubst (x |-> bindId by) p)
    in case wrapped of
         RBase by typ q   -> RBase by typ (merged by q)
         RRTy by inner' q -> flattenRType (RRTy by inner' (merged by q))
         _                -> outer
  _ -> outer


-- TODO: RApp?
-- | (x :: t) => c
--
-- Quantifies the constraint over @x@ at the given refinement type. The type
-- is flattened first; for @RBase@ and @RRTy@ the refinement (with its binder
-- renamed to @x@) is attached to an @All@. For every other type the
-- constraint is returned unchanged.
mkAll :: CGenConstraints r a => Id -> RType r a -> NNF r -> NNF r
mkAll x rtype body = case flattenRType rtype of
    RBase (Bind y _) b p     -> forAll y b p
    RRTy (Bind y _) inner p  -> forAll y (eraseRType inner) p
    _                        -> body
  where
    forAll y sort p = All x sort (varSubst (y |-> x) p) body

-- | ∃ x :: t. c
--
-- Counterpart of 'mkAll' that builds an @Any@. The type is flattened first;
-- for @RBase@ and @RRTy@ the refinement (with its binder renamed to @x@) is
-- attached. For every other type the constraint is returned unchanged.
mkExists x rtype body = case flattenRType rtype of
    RBase (Bind y _) b p     -> exists y b p
    RRTy (Bind y _) inner p  -> exists y (eraseRType inner) p
    _                        -> body
  where
    exists y sort p = Any x sort (varSubst (y |-> x) p) body

-- | Peels the leading implicit-function (@RIFun@) layers off a type. Returns
-- the implicit parameters as @(name, type)@ pairs, outermost first, together
-- with the first type that is not an @RIFun@.
splitImplicits :: RType r a -> ([(Id, RType r a)], RType r a)
splitImplicits rtype = case rtype of
  RIFun b t rest ->
    let (implicits, core) = splitImplicits rest
    in ((bindId b, t) : implicits, core)
  _ -> ([], rtype)
back to top