(**************************************************************************)
(* Copyright (c) 2010, Romain BARDOU                                      *)
(* All rights reserved.                                                   *)
(*                                                                        *)
(* Redistribution and  use in  source and binary  forms, with  or without *)
(* modification, are permitted provided that the following conditions are *)
(* met:                                                                   *)
(*                                                                        *)
(* * Redistributions  of  source code  must  retain  the above  copyright *)
(*   notice, this list of conditions and the following disclaimer.        *)
(* * Redistributions in  binary form  must reproduce the  above copyright *)
(*   notice, this list of conditions  and the following disclaimer in the *)
(*   documentation and/or other materials provided with the distribution. *)
(* * Neither the  name of Capucine nor  the names of its contributors may *)
(*   be used  to endorse or  promote products derived  from this software *)
(*   without specific prior written permission.                           *)
(*                                                                        *)
(* THIS SOFTWARE  IS PROVIDED BY  THE COPYRIGHT HOLDERS  AND CONTRIBUTORS *)
(* "AS  IS" AND  ANY EXPRESS  OR IMPLIED  WARRANTIES, INCLUDING,  BUT NOT *)
(* LIMITED TO, THE IMPLIED  WARRANTIES OF MERCHANTABILITY AND FITNESS FOR *)
(* A PARTICULAR PURPOSE  ARE DISCLAIMED. IN NO EVENT  SHALL THE COPYRIGHT *)
(* OWNER OR CONTRIBUTORS BE  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, *)
(* SPECIAL,  EXEMPLARY,  OR  CONSEQUENTIAL  DAMAGES (INCLUDING,  BUT  NOT *)
(* LIMITED TO, PROCUREMENT OF SUBSTITUTE  GOODS OR SERVICES; LOSS OF USE, *)
(* DATA, OR PROFITS; OR BUSINESS  INTERRUPTION) HOWEVER CAUSED AND ON ANY *)
(* THEORY OF  LIABILITY, WHETHER IN  CONTRACT, STRICT LIABILITY,  OR TORT *)
(* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING  IN ANY WAY OUT OF THE USE *)
(* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.   *)
(**************************************************************************)

open Misc
open Ast
open Unify

(* [check_region loc env r] reports an error at [loc] (through [check]) when
   the region variable at the base of [r] is not bound in [env]; unification
   variables are accepted as they are.
   For a sub-region "R.r" only "R" is checked, not the field "r": no
   permission on "R.r" can be produced if "r" does not exist anyway, but
   the check should still be done someday. *)
let rec check_region loc env r =
  match r with
  | RSub (base, _) -> check_region loc env base
  | RUVar _ -> ()
  | RVar id ->
      check loc (Env.mem_region env id) "Unbound region variable: %s" id

(* [check_class loc env c] reports an error at [loc] (through [check])
   unless [c] is a pointer class declared in [env]. *)
let check_class loc env c =
  let known = Env.mem_class env c in
  check loc known "Unbound pointer class: %s" c

(* [check_param_lengths loc env c rl tl] checks that class [c] (fetched with
   [Env.get_class], so it must already be known) is applied to as many type
   arguments [tl] and region arguments [rl] as its declaration has type and
   region parameters. The type arity is checked first. *)
let check_param_lengths loc env c rl tl =
  let cls = Env.get_class env c in
  let expected_types = List.length cls.c_type_params
  and given_types = List.length tl in
  check loc (expected_types = given_types)
    "Class %s expects %d type parameters, but is here applied to %d"
    c expected_types given_types;
  let expected_regions = List.length cls.c_region_params
  and given_regions = List.length rl in
  check loc (expected_regions = given_regions)
    "Class %s expects %d region parameters, but is here applied to %d"
    c expected_regions given_regions

(* [typ env te] converts the type expression [te] into a [typ], records the
   result in [te.typ] and returns it. The identifiers "unit", "int" and
   "bool" denote the base types; any other identifier must be bound in
   [env] as a type or a type variable and becomes a [TVar]. For a pointer
   type, the class, its region arguments and the owner region must be
   bound, and the class must be applied to the right number of arguments. *)
let rec typ env x =
  let result =
    match x.node with
    | TEIdent "unit" -> TBase TUnit
    | TEIdent "int" -> TBase TInt
    | TEIdent "bool" -> TBase TBool
    | TEIdent name ->
        let bound = Env.mem_type env name || Env.mem_type_var env name in
        check x.loc bound "Unbound type variable: %s" name;
        TVar name
    | TESum (a, b) ->
        TSum (typ env a, typ env b)
    | TETuple components ->
        TTuple (List.map (typ env) components)
    | TEPointer ((regions, args, cls), owner) ->
        let here = x.loc in
        check_class here env cls;
        List.iter (check_region here env) regions;
        check_region here env owner;
        check_param_lengths here env cls regions args;
        TPointer (regions, List.map (typ env) args, cls, owner)
  in
  x.typ <- result;
  result

(* [subst_type id t ty] replaces every [TVar id] occurring in [ty] by [t],
   including inside pointer type arguments. Region annotations are left
   untouched and unification variables are not looked through. *)
let rec subst_type id t ty =
  let recur = subst_type id t in
  match ty with
  | TVar v -> if v = id then t else ty
  | TBase _ | TUVar _ -> ty
  | TTuple parts -> TTuple (List.map recur parts)
  | TSum (left, right) -> TSum (recur left, recur right)
  | TPointer (rl, tl, c, r) -> TPointer (rl, List.map recur tl, c, r)

(* [subst_regreg id r reg] replaces the region variable [RVar id] by [r]
   inside the region [reg]. For a sub-region only the base is rewritten;
   the field name is kept as is. *)
let rec subst_regreg id r reg =
  match reg with
  | RSub (base, field) -> RSub (subst_regreg id r base, field)
  | RVar v -> if v = id then r else reg
  | RUVar _ -> reg

(* [subst_region id r ty] replaces the region variable [RVar id] by [r] in
   every region annotation of the type [ty]: the region arguments and owner
   region of each pointer type, recursively inside type arguments, tuples
   and sums. *)
let rec subst_region id r ty =
  let on_region = subst_regreg id r in
  let on_type = subst_region id r in
  match ty with
  | TBase _ | TVar _ | TUVar _ -> ty
  | TTuple parts -> TTuple (List.map on_type parts)
  | TSum (left, right) -> TSum (on_type left, on_type right)
  | TPointer (rl, tl, c, owner) ->
      TPointer (
        List.map on_region rl,
        List.map on_type tl,
        c,
        on_region owner
      )

(* [pointed_type loc env t] returns the type of the value stored behind a
   pointer of type [t] = [TPointer (rl, tl, c, r)]. This is the declared type
   of class [c] in which:
   - each class type parameter is replaced by the matching type in [tl],
   - each class region parameter is replaced by the matching region in [rl],
   - each owned region [o] is replaced by the sub-region [RSub (r, o)],
   - the class self region, if any, is replaced by [r].
   All replacements are done simultaneously. The arguments [tl], [rl] and
   [r] are written in the caller's namespace, so they must never be
   rewritten again when one of their names coincides with a class-level
   name. When a name is declared in several of the lists above, the earlier
   list wins.
   [t] is passed through [expand_type] first, so a unification variable
   already bound to a pointer type is accepted.
   Reports a located error at [loc] if [t] is not a pointer. Raises
   [Invalid_argument] if the argument counts differ from the class
   declaration; [typ] and [expr] check them with [check_param_lengths]. *)
let pointed_type loc env t =
  match expand_type t with
  | TPointer (rl, tl, c, r) ->
      let cd = Env.get_class env c in
      let type_map = List.combine cd.c_type_params tl in
      let region_map =
        List.combine cd.c_region_params rl
        @ List.map (fun id -> id, RSub (r, id)) cd.c_owned_regions
        @ Opt.fold (fun acc id -> (id, r) :: acc) [] cd.c_self_region
      in
      (* names not bound by the class (e.g. global types) are kept *)
      let lookup map id default =
        try List.assoc id map with Not_found -> default
      in
      let rec on_region = function
        | RVar id as reg -> lookup region_map id reg
        | RUVar _ as reg -> reg
        | RSub (base, field) -> RSub (on_region base, field)
      in
      (* substituted images are returned as is, never traversed again *)
      let rec on_type = function
        | TVar id as ty -> lookup type_map id ty
        | TBase _ | TUVar _ as ty -> ty
        | TTuple parts -> TTuple (List.map on_type parts)
        | TSum (a, b) -> TSum (on_type a, on_type b)
        | TPointer (rl', tl', c', r') ->
            TPointer (
              List.map on_region rl',
              List.map on_type tl',
              c',
              on_region r'
            )
      in
      on_type cd.c_type.typ
  | _ ->
      Loc.locate_error loc "This should be a pointer"

(* [instanciate_region regions r] rewrites the region variable at the base
   of [r] to its image in the association list [regions], keeping sub-region
   field names. Raises [Not_found] if that variable is missing from
   [regions]; unification variables are not expected in [r]. *)
let rec instanciate_region regions r =
  match r with
  | RSub (base, field) -> RSub (instanciate_region regions base, field)
  | RVar id -> List.assoc id regions
  | RUVar _ -> assert false (* should not happen *)

(* [instanciate_type regions vars ty] replaces every [TVar] of [ty] by its
   image in the association list [vars] and every region variable by its
   image in [regions]. [instanciate_value] calls it with fresh [TUVar]s and
   [RUVar]s. Raises [Not_found] for a variable missing from the lists;
   unification variables are not expected in [ty]. *)
let rec instanciate_type regions vars ty =
  let on_type = instanciate_type regions vars
  and on_region = instanciate_region regions in
  match ty with
  | TBase _ -> ty
  | TVar id -> List.assoc id vars
  | TUVar _ -> assert false (* should not happen *)
  | TTuple parts -> TTuple (List.map on_type parts)
  | TSum (a, b) -> TSum (on_type a, on_type b)
  | TPointer (rl, tl, c, r) ->
      TPointer (
        List.map on_region rl,
        List.map on_type tl,
        c,
        on_region r
      )

(* [add_rvar acc r] adds to the set [acc] the name of the region variable at
   the base of [r]; unification variables contribute nothing. *)
let rec add_rvar acc r =
  match r with
  | RSub (base, _) -> add_rvar acc base
  | RVar name -> StringSet.add name acc
  | RUVar _ -> acc

(* [free_regions_of_type acc ty] adds to the set [acc] the names of the
   region variables occurring in [ty]: the owner region and region
   arguments of every pointer type, recursively inside type arguments,
   tuples and sums. *)
let rec free_regions_of_type acc ty =
  match ty with
  | TTuple parts -> List.fold_left free_regions_of_type acc parts
  | TSum (a, b) -> List.fold_left free_regions_of_type acc [ a; b ]
  | TPointer (rl, tl, _, owner) ->
      let with_regions = List.fold_left add_rvar (add_rvar acc owner) rl in
      List.fold_left free_regions_of_type with_regions tl
  | TVar _ | TUVar _ | TBase _ -> acc

(* [free_variables_of_type acc ty] adds to the set [acc] the names of all
   [TVar]s occurring in [ty], including inside pointer type arguments. *)
let rec free_variables_of_type acc ty =
  match ty with
  | TVar name -> StringSet.add name acc
  | TSum (a, b) -> List.fold_left free_variables_of_type acc [ a; b ]
  | TTuple parts | TPointer (_, parts, _, _) ->
      List.fold_left free_variables_of_type acc parts
  | TUVar _ | TBase _ -> acc

(* [instanciate_value v] returns a fresh instance of the signature of [v] as
   a triple: the instantiated parameter types, the instantiated return type,
   and the association list mapping each free region variable of the
   signature to the fresh [RUVar] that replaced it. Free type variables are
   likewise replaced by fresh [TUVar]s. Fresh region variables are allocated
   before fresh type variables. *)
let instanciate_value v =
  let declared = v.v_return_type :: List.map snd v.v_params in
  let signature = List.map (fun te -> te.typ) declared in
  let regions =
    StringSet.elements
      (List.fold_left free_regions_of_type StringSet.empty signature)
  in
  let region_subst = List.map (fun name -> name, RUVar (fresh ())) regions in
  let vars =
    StringSet.elements
      (List.fold_left free_variables_of_type StringSet.empty signature)
  in
  let var_subst = List.map (fun name -> name, TUVar (fresh ())) vars in
  let inst ty = instanciate_type region_subst var_subst ty in
  let params = List.map (fun (_, te) -> inst te.typ) v.v_params in
  params, inst v.v_return_type.typ, region_subst

(* [expr env x] infers the type of the expression [x] in [env], stores it in
   [x.typ] and returns it. Sub-expression types are unified with the types
   each construct expects. [Assign] calls [unify_types] with the [true]
   flag; every other construct uses [false]. Errors are reported at the
   location of the offending (sub-)expression.
   When a construct inspects the shape of a sub-expression's type (tuple or
   pointer), that type first goes through [expand_type], so a unification
   variable already bound to such a type (e.g. the result of a polymorphic
   [Call]) is accepted.
   For a [Call], the region instantiation chosen for the callee is stored in
   the call's [rs_ref]. *)
let rec expr env x =
  let unify = unify_types x.loc false env in
  let unify_sub = unify_types x.loc true env in
  (* type of a sub-expression of [x] (typed in [env]), with solved
     unification variables expanded so that its shape can be matched on *)
  let expanded e = expand_type (expr env e) in
  let t = match x.node with
    | Const CUnit -> TBase TUnit
    | Const (CInt _) -> TBase TInt
    | Const (CBool _) -> TBase TBool
    | Unop (op, e) ->
        let t = match op with
          | #int_un_op -> TInt
          | #bool_un_op -> TBool
        in
        let t = TBase t in
        unify (expr env e) t;
        t
    | Binop (op, e1, e2) ->
        let t = match op with
          | #int_bin_op -> TInt
          | #bool_bin_op -> TBool
        in
        let t = TBase t in
        unify (expr env e1) t;
        unify (expr env e2) t;
        t
    | Tuple el ->
        TTuple (List.map (expr env) el)
    | Proj (e, i) ->
        if i > 0 then
          begin match expanded e with
            | TTuple tl ->
                if List.length tl >= i then
                  List.nth tl (i - 1)
                else
                  (* report the requested index, not the actual arity *)
                  Loc.locate_error e.loc
                    "This tuple does not have %d components" i
            | _ ->
                Loc.locate_error e.loc "This should be a tuple"
          end
        else
          Loc.locate_error x.loc "Projection indices start from 1"
    | Left e ->
        TSum (expr env e, TUVar (fresh ()))
    | Right e ->
        TSum (TUVar (fresh ()), expr env e)
    | Var id ->
        begin try
          Env.get_var env id
        with Not_found ->
          Loc.locate_error x.loc "Variable not found: %s" id
        end
    | Let (id, e1, e2) ->
        let env2 = Env.add_var env id (expr env e1) in
        expr env2 e2
    | Seq (e1, e2) ->
        unify (expr env e1) (TBase TUnit);
        expr env e2
    | Call (id, el, rs_ref) ->
        (* only the lookup of [id] is expected to raise [Not_found]; a narrow
           handler keeps an unrelated [Not_found] from being misreported as
           an unknown value *)
        let v =
          try Env.get_value env id
          with Not_found -> Loc.locate_error x.loc "Value not found: %s" id
        in
        let (ptl, rt, rs) = instanciate_value v in
        let tl = List.map (expr env) el in
        (* report arity mismatches instead of crashing in [List.iter2] *)
        let expected = List.length ptl and given = List.length tl in
        if expected <> given then
          Loc.locate_error x.loc
            "Value %s expects %d arguments, but is here applied to %d"
            id expected given;
        List.iter2 unify tl ptl;
        rs_ref := rs;
        rt
    | If (e1, e2, e3) ->
        unify (expr env e1) (TBase TBool);
        let t = expr env e2 in
        unify (expr env e3) t;
        t
    | While (e1, e2) ->
        unify (expr env e1) (TBase TBool);
        unify (expr env e2) (TBase TUnit);
        TBase TUnit
    | Assign (e1, e2) ->
        let t = pointed_type e1.loc env (expanded e1) in
        unify_sub (expr env e2) t;
        TBase TUnit
    | Deref e ->
        pointed_type e.loc env (expanded e)
    | New ((rl, tl, c), r) ->
        check_class x.loc env c;
        List.iter (check_region x.loc env) rl;
        check_region x.loc env r;
        check_param_lengths x.loc env c rl tl;
        TPointer (rl, List.map (typ env) tl, c, r)
    | Pack e
    | Unpack e ->
        begin match expanded e with
          | TPointer _ -> TBase TUnit
          | _ -> Loc.locate_error e.loc "This should be a pointer"
        end
    | Adopt (e, r)
    | Focus (e, r) ->
        (* the target region must be bound, as for [New] and [Unfocus] *)
        check_region x.loc env r;
        begin match expanded e with
          | TPointer (rl, tl, c, _) -> TPointer (rl, tl, c, r)
          | _ -> Loc.locate_error e.loc "This should be a pointer"
        end
    | FocusBind (e, vid, rid, body) ->
        begin match expanded e with
          | TPointer (rl, tl, c, r) ->
              let env = Env.add_sub_region env rid r in
              let env = Env.add_var env vid (TPointer (rl, tl, c, RVar rid)) in
              expr env body
          | _ ->
              Loc.locate_error e.loc "This should be a pointer"
        end
    | Unfocus (e, r) ->
        check_region x.loc env r;
        begin match expanded e with
          | TPointer _ -> TBase TUnit
          | _ -> Loc.locate_error e.loc "This should be a pointer"
        end
    | Region (id, e) ->
        let env = Env.add_region env id in
        expr env e
    | Print (_, e) ->
        expr env e
    | UnpackRegion _ | PackRegion _ | UnfocusRegion _ | AdoptRegion _
    | BlackBox _ ->
        TBase TUnit
  in
  x.typ <- t;
  t

(* [exec_prints_expr x] walks the already-typed expression [x],
   sub-expressions in source order. For every [Print (s, e)] node it finds,
   it logs [s] followed by the type recorded in [e.typ], passed through
   [expand_type]. Nested [Print] nodes are all reported, innermost first.
   [value_def] calls it only after the body has been typed and unified with
   the return type, presumably so the printed types are as resolved as
   possible. *)
let rec exec_prints_expr x =
  match x.node with
    | Const _
    | Var _
    | New _
    | UnpackRegion _
    | PackRegion _
    | UnfocusRegion _
    | AdoptRegion _
    | BlackBox _ ->
        ()

    | Unop (_, e)
    | Proj (e, _)
    | Left e
    | Right e
    | Deref e
    | Pack e
    | Unpack e
    | Adopt (e, _)
    | Focus (e, _)
    | Unfocus (e, _)
    | Region (_, e) ->
        exec_prints_expr e

    | Binop (_, e1, e2)
    | Let (_, e1, e2)
    | Seq (e1, e2)
    | While (e1, e2)
    | Assign (e1, e2)
    | FocusBind (e1, _, _, e2) ->
        exec_prints_expr e1;
        exec_prints_expr e2

    | If (e1, e2, e3) ->
        exec_prints_expr e1;
        exec_prints_expr e2;
        exec_prints_expr e3

    | Tuple el
    | Call (_, el, _) ->
        List.iter exec_prints_expr el

    | Print (s, e) ->
        (* [e] may itself contain [Print] nodes: report them too, before
           this one *)
        exec_prints_expr e;
        log "%s: %a@." s Pp.typ (expand_type e.typ)

(* [class_def env x] checks the class declaration [x] and returns [env]
   extended with it. The class is added before its body type is checked,
   so the body may refer to the class itself. The body type [x.c_type] is
   checked in an internal environment that binds the region parameters,
   the self region (if any), the type parameters (each bound to the [TVar]
   of its own name) and the owned regions. *)
let class_def env x =
  let env = Env.add_class env x.c_name x in
  let add_regions e names = List.fold_left Env.add_region e names in
  let inner = add_regions env x.c_region_params in
  let inner = Opt.fold Env.add_region inner x.c_self_region in
  let inner =
    List.fold_left
      (fun e name -> Env.add_type e name (TVar name))
      inner
      x.c_type_params
  in
  let inner = add_regions inner x.c_owned_regions in
  ignore (typ inner x.c_type);
  env

(* [free_regions_of_type_expr acc te] adds to the set [acc] the names of the
   region variables written in the type expression [te]: the owner region
   and region arguments of every pointer type, recursively inside type
   arguments, tuples and sums. *)
let rec free_regions_of_type_expr acc x =
  match x.node with
  | TETuple parts -> List.fold_left free_regions_of_type_expr acc parts
  | TESum (a, b) -> List.fold_left free_regions_of_type_expr acc [ a; b ]
  | TEPointer ((rl, tl, _), owner) ->
      let with_regions = List.fold_left add_rvar (add_rvar acc owner) rl in
      List.fold_left free_regions_of_type_expr with_regions tl
  | TEIdent _ -> acc

(* [free_variables_of_type_expr acc te] adds to the set [acc] every
   identifier written in [te] other than the base types "int", "unit" and
   "bool", including inside pointer type arguments.
   NOTE(review): any other identifier is collected, whether or not it also
   names a declared type. *)
let rec free_variables_of_type_expr acc x =
  match x.node with
  | TEIdent name ->
      (* a little hackish: base types are recognised by their spelling *)
      begin match name with
        | "int" | "unit" | "bool" -> acc
        | _ -> StringSet.add name acc
      end
  | TESum (a, b) -> List.fold_left free_variables_of_type_expr acc [ a; b ]
  | TETuple parts | TEPointer ((_, parts, _), _) ->
      List.fold_left free_variables_of_type_expr acc parts

(* [value_def env x] checks the value declaration [x] and returns [env]
   extended with it. The value is added before its body is checked, so the
   body may call it recursively. The region variables and type identifiers
   occurring free in the declared signature are bound in an internal
   environment, and the signature is translated with [typ]. If [x] has a
   body, its type is unified with the declared return type ([unify_types]
   with the [true] flag), and then the body's [Print] nodes are reported.
   A confirmation line is logged at the end. *)
let value_def env x =
  let env = Env.add_value env x.v_name x in
  let signature = x.v_return_type :: List.map snd x.v_params in
  let free_in_signature collect =
    StringSet.elements (List.fold_left collect StringSet.empty signature)
  in
  (* quantified regions, then quantified type variables *)
  let ienv =
    List.fold_left Env.add_region env
      (free_in_signature free_regions_of_type_expr)
  in
  let ienv =
    List.fold_left Env.add_type_var ienv
      (free_in_signature free_variables_of_type_expr)
  in
  let return_type = typ ienv x.v_return_type in
  let ienv =
    List.fold_left
      (fun acc (name, te) -> Env.add_var acc name (typ acc te))
      ienv
      x.v_params
  in
  begin match x.v_body with
    | Some body ->
        let body_type = expr ienv body in
        unify_types x.v_return_type.loc true ienv body_type return_type;
        exec_prints_expr body
    | None -> ()
  end;
  log "val %s: typing ok\n%!" x.v_name;
  env

(* [def env d] checks one top-level declaration and returns the environment
   extended with it. *)
let def env d =
  match d with
  | Value v -> value_def env v
  | Class c -> class_def env c

(* [file env defs] checks the declarations [defs] in order, each one in the
   environment produced by the previous ones, and returns the final
   environment. *)
let file env defs =
  List.fold_left (fun acc d -> def acc d) env defs
