You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 

105 lines
2.8 KiB

open Lang
(** [fail s] reports the error [s] through [Complice.Utils.failwith], with
    the message prefixed by ["type inference: "]. *)
let fail s =
  let message = "type inference: " ^ s in
  Complice.Utils.failwith message
(** Generative functor holding the mutable state of one type-inference run.
    Every application [Inference ()] creates its own substitution table, its
    own table of inferred types and its own fresh-variable counter. *)
module Inference () = struct
  (** Substitutions recorded by [unify]: each binding maps a type to the type
      that must replace it. Bindings are inserted with [Hashtbl.add], so a key
      can carry several bindings at once. *)
  let subst = Hashtbl.create 512

  (** [add_subst t t'] records that [t] has to be replaced by [t']. *)
  let add_subst t t' = Hashtbl.add subst t t'

  (** [unify (a, b)] records the substitutions needed to make [a] and [b]
      equal:
      - two equal primitives need nothing;
      - a primitive and a variable bind the variable to the primitive;
      - two arrows unify their left components, then their right components;
      - two distinct variables bind the first one to the second one.

      Every other combination (for instance two different primitives, or a
      variable against an arrow) is rejected through [fail]. *)
  let rec unify =
    let open Types in
    function
    | Primitive t, Primitive t' when t = t' ->
        ()
    | Primitive t, Variable v | Variable v, Primitive t ->
        add_subst (Variable v) (Primitive t)
    | Arrow (t1, t1'), Arrow (t2, t2') ->
        unify (t1, t2) ;
        unify (t1', t2')
    | Variable v, Variable v' ->
        (* A variable is trivially equal to itself. Recording [v := v] would
           put a self-loop in [subst], which the cycle check performed by the
           top-level [file] reports as a recursive type. *)
        if v <> v' then add_subst (Variable v) (Variable v')
    | _ ->
        fail "can't unify (unsatisfiable constraint)"

  (** Types inferred so far, keyed by the identifier they belong to. *)
  let infered = Hashtbl.create 512

  (** [mk_fresh x] creates a new type variable named ["_t<n>"] (the counter
      starts at 0 and is private to this functor instance), records it as the
      type of [x] in [infered] and returns it. *)
  let mk_fresh =
    let count = ref (-1) in
    fun x ->
      incr count ;
      let res = Types.Variable (Format.sprintf "_t%d" !count) in
      Hashtbl.add infered x res ;
      res

  (** [type_of_var x] is the type currently recorded for [x] in [infered]; when
      there is none, a fresh type variable is created and recorded for [x]. *)
  let type_of_var x =
    match Hashtbl.find_opt infered x with
    | Some t ->
        t
    | None ->
        mk_fresh x

  (** [literal l] is the primitive type of the literal [l]. *)
  let literal = function Unit -> Types.Unit | Bool _ -> Types.Bool

  (** [const c] is the type of the constant [c]: the primitive type of a
      literal, or the (possibly fresh) type recorded for a variable. *)
  let const = function
    | Literal l ->
        Types.Primitive (literal l)
    | Var x ->
        type_of_var x

  (** [expr e] infers the type of [e], filling [infered] and [subst] on the
      way.
      - [Bind (p, e, e')]: the type of [e] is recorded for [p], the result is
        the type of [e'].
      - [Abstract (p, e)]: the result is an arrow from the type of [p] to the
        type of [e].
      - [Apply (e, e')]: the type of [e] must already be an arrow; its input
        is unified with the type of [e'] and its output is the result.
        Anything else is rejected through [fail]. *)
  let rec expr = function
    | Const c ->
        const c
    | Bind (p, e, e') ->
        let t = expr e in
        Hashtbl.add infered p t ;
        expr e'
    | Abstract (p, e) ->
        (* The body is inferred first so that uses of [p] register its type. *)
        let t = expr e in
        (* A parameter that the body never mentions has no entry yet: give it
           a fresh variable instead of letting [Not_found] escape. *)
        Types.Arrow (type_of_var p, t)
    | Apply (e, e') -> (
        let t = expr e in
        let t' = expr e' in
        match t with
        | Types.Arrow (t_in, t_out) ->
            unify (t_in, t') ;
            t_out
        | _ ->
            fail
              (Format.asprintf
                 "%a has type %a, it is not a function, it can't be applied"
                 Pp.fprintf_expr e Pp.fprintf_type t) )

  (** [file f] infers the type of the whole program [f]. *)
  let file f = expr f
end
(** [file f] runs type inference on [f] with a fresh [Inference] instance and
    resolves the recorded substitutions.

    The steps are:
    - infer the type of [f];
    - reject, through [fail], every substitution chain that comes back to the
      type it started from;
    - apply the substitutions to the result type and to every inferred type,
      repeating until a full pass leaves the inferred types unchanged.

    Returns the table of inferred types together with the resolved type of
    [f]. *)
let file f =
  let module M = Inference () in
  let res = ref (M.file f) in
  let check_cycle orig x =
    (* Types already met on this walk. Without it, a chain that enters a loop
       which does not contain [orig] would be followed forever; such a loop is
       reported when one of its own members is used as [orig]. *)
    let visited = Hashtbl.create 16 in
    let rec aux x =
      if not (Hashtbl.mem visited x) then begin
        Hashtbl.add visited x () ;
        match Hashtbl.find_opt M.subst x with
        | None ->
            ()
        | Some y ->
            (* TODO: print the list of ids in which orig appears... *)
            if y = orig then
              fail
                (Format.asprintf
                   "type %a is recursive, stop doing this please... ( \
                    https://youtu.be/mqA2evDu4Mw )"
                   Pp.fprintf_type orig) ;
            aux y
      end
    in
    aux x
  in
  Hashtbl.iter (fun k _ -> check_cycle k k) M.subst ;
  let keep_on = ref true in
  while !keep_on do
    keep_on := false ;
    Hashtbl.iter
      (fun old_type new_type ->
        res := Types.subst old_type new_type !res ;
        (* Rewrite every binding in place: calling [Hashtbl.replace] from
           inside [Hashtbl.iter] on the same table is unspecified, and it
           would overwrite the most recent binding of a key even when the
           binding being visited is an older, shadowed one. *)
        Hashtbl.filter_map_inplace
          (fun _ var_type ->
            let new_var_type = Types.subst old_type new_type var_type in
            if var_type <> new_var_type then keep_on := true ;
            Some new_var_type)
          M.infered)
      M.subst
  done ;
  (M.infered, !res)