(**************************************************************************)
(* Copyright (c) 2010, Romain BARDOU                                      *)
(* All rights reserved.                                                   *)
(*                                                                        *)
(* Redistribution and  use in  source and binary  forms, with  or without *)
(* modification, are permitted provided that the following conditions are *)
(* met:                                                                   *)
(*                                                                        *)
(* * Redistributions  of  source code  must  retain  the above  copyright *)
(*   notice, this list of conditions and the following disclaimer.        *)
(* * Redistributions in  binary form  must reproduce the  above copyright *)
(*   notice, this list of conditions  and the following disclaimer in the *)
(*   documentation and/or other materials provided with the distribution. *)
(* * Neither the  name of Capucine nor  the names of its contributors may *)
(*   be used  to endorse or  promote products derived  from this software *)
(*   without specific prior written permission.                           *)
(*                                                                        *)
(* THIS SOFTWARE  IS PROVIDED BY  THE COPYRIGHT HOLDERS  AND CONTRIBUTORS *)
(* "AS  IS" AND  ANY EXPRESS  OR IMPLIED  WARRANTIES, INCLUDING,  BUT NOT *)
(* LIMITED TO, THE IMPLIED  WARRANTIES OF MERCHANTABILITY AND FITNESS FOR *)
(* A PARTICULAR PURPOSE  ARE DISCLAIMED. IN NO EVENT  SHALL THE COPYRIGHT *)
(* OWNER OR CONTRIBUTORS BE  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, *)
(* SPECIAL,  EXEMPLARY,  OR  CONSEQUENTIAL  DAMAGES (INCLUDING,  BUT  NOT *)
(* LIMITED TO, PROCUREMENT OF SUBSTITUTE  GOODS OR SERVICES; LOSS OF USE, *)
(* DATA, OR PROFITS; OR BUSINESS  INTERRUPTION) HOWEVER CAUSED AND ON ANY *)
(* THEORY OF  LIABILITY, WHETHER IN  CONTRACT, STRICT LIABILITY,  OR TORT *)
(* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING  IN ANY WAY OUT OF THE USE *)
(* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.   *)
(**************************************************************************)

open Ast
open Bast
open Misc

(* A region constraint: a pair of regions that unification did not merge by
   substitution and instead returns to its caller (see the [c list] results
   of [unify] and [add] below).  Presumably checked or solved later by the
   caller; TODO confirm. *)
type c = Bast.region * Bast.region

(* Raised when two terms cannot be unified.  The two functions print the
   pair of sub-terms on which unification failed. *)
exception Unification_error
  of (Format.formatter -> unit -> unit) * (Format.formatter -> unit -> unit)

(* Raised when an identifier would be bound to a term containing it. *)
exception Occur_check_error

(* [unification_error left right] raises [Unification_error (left, right)];
   it never returns, so it can be used at any type. *)
let unification_error left right =
  let error = Unification_error (left, right) in
  raise error

(* Raise [Occur_check_error]; never returns. *)
let occur_check_error () =
  raise Occur_check_error

(* Operations a kind of term must provide to be used with the [Subst]
   functor. *)
module type UNIFIABLE = sig
  type t
  (* [unify flag a b] unifies [a] and [b], possibly extending the
     substitution for this kind of term, and returns the region constraints
     left over.  The implementations in this file raise [Unification_error]
     on a mismatch.  The meaning of the boolean flag is not visible here:
     [Subst.add] passes [false] and the implementations only forward it
     (TODO confirm its intent). *)
  val unify: bool -> t -> t -> c list
  (* [subst i v t] replaces identifier [i] by [v] everywhere in [t]. *)
  val subst: Ident.t -> t -> t -> t
  (* [occur_check i t] raises [Occur_check_error] if [i] occurs in [t],
     and returns [()] otherwise. *)
  val occur_check: Ident.t -> t -> unit
end

(* A mutable substitution from identifiers to terms of type [v]. *)
module type SUBST = sig
  type v
  type t
  (* Create an empty substitution. *)
  val create: unit -> t
  (* [add s i v] binds [i] to [v] in [s] (in place).  If [i] was already
     bound, the previous and new values are unified and the resulting region
     constraints are returned; otherwise the result is [[]].  The [Subst]
     implementation raises [Occur_check_error] if [i] occurs in [v]. *)
  val add: t -> Ident.t -> v -> c list
  (* [apply s v] applies the bindings of [s] to [v]. *)
  val apply: t -> v -> v
end

(* Generic mutable substitution over any [UNIFIABLE] kind of term. *)
module Subst(X: UNIFIABLE): SUBST with type v = X.t = struct
  type v = X.t

  (* Bindings, most recent first.  [add] maintains the invariant that no
     bound identifier occurs in any right-hand side, so a single pass of
     [apply] replaces every bound identifier of a term. *)
  type t = (Ident.t * X.t) list ref

  let create () =
    ref []

  (* [sep i l] returns the value bound to [i] in [l] together with [l]
     without that binding.  Raises [Not_found] if [i] is not bound. *)
  let rec sep i = function
    | [] ->
        raise Not_found
    | (j, v) :: r when i = j ->
        v, r
    | x :: r ->
        let v, r' = sep i r in
        v, x :: r'

  (* Replace [i] by [v] in every right-hand side of [s].
     Assumes [i] is not bound in [s]. *)
  let subst_in_subst s i v =
    List.map (fun (j, w) -> j, X.subst i v w) s

  (* Apply every binding of [s] to [v], oldest binding first. *)
  let apply s v =
    List.fold_left (fun v (i, w) -> X.subst i w v) v (List.rev !s)

  (* [add s i v] binds [i] to [v] and returns the region constraints
     produced by unification, if any.
     - [v] is first expanded with the current bindings.  Otherwise [v] could
       mention an already-bound identifier [j] whose value contains [i]: the
       occur check would miss the cycle (e.g. [j |-> T(i)] then
       [add i T(j)]), and the invariant on [t] would be broken, so [apply]
       could leave bound identifiers in its result.
     - If [i] was already bound, its old and new values are unified.
     Raises [Occur_check_error] if [i] occurs in the expanded [v]. *)
  let add s i v =
    let v = apply s v in
    X.occur_check i v;
    (* Only [sep] is guarded: a [Not_found] escaping from [X.unify] must not
       be mistaken for "[i] is unbound". *)
    let previous = try Some (sep i !s) with Not_found -> None in
    match previous with
      | None ->
          s := (i, v) :: subst_in_subst !s i v;
          []
      | Some (v', s') ->
          s := (i, v) :: subst_in_subst s' i v;
          (* no loop as i is not in v (occur check above) nor in v'
             (invariant on [t]) *)
          X.unify false v v'
end

(* Interface of the [Region] and [Type] unification modules below: the
   [UNIFIABLE] operations plus [expand]. *)
module type UNIFY = sig
  type t
  (* [unify flag a b] unifies [a] and [b], binding variables in the module's
     own substitution, and returns the region constraints left over.  The
     implementations in this file raise [Unification_error] on a mismatch
     and only forward the boolean flag (TODO confirm its intent). *)
  val unify: bool -> t -> t -> c list
  (* [subst i v t] replaces identifier [i] by [v] everywhere in [t]. *)
  val subst: Ident.t -> t -> t -> t
  (* [occur_check i t] raises [Occur_check_error] if [i] occurs in [t]. *)
  val occur_check: Ident.t -> t -> unit
  (* [expand t] applies the module's current substitution to [t]. *)
  val expand: t -> t
end

module rec RSubst: SUBST with type v = Bast.region = Subst(Region)

(* Unification of regions.  Unifiable region variables are bound in the
   module-level substitution [table]; pairs that cannot be merged by
   substitution are returned as constraints rather than rejected. *)
and Region: UNIFY with type t = Bast.region = struct
  type t = Bast.region

  (* Substitution of region identifiers, shared by every call to [unify]
     and updated in place by [RSubst.add]. *)
  let table = RSubst.create ()

  (* Apply the current region substitution to [r]. *)
  let expand r =
    RSubst.apply table r

  (* [unify sub a b] unifies regions [a] and [b] and returns the remaining
     constraints.  [sub] is only forwarded to recursive calls.  Raises
     [Unification_error] when two [RSub] regions have different second
     components, and lets [Occur_check_error] from [RSubst.add] escape. *)
  let rec unify sub a b =
    let fail () =
      unification_error
        (fun fmt () -> Pp.region fmt a)
        (fun fmt () -> Pp.region fmt b)
    in
    Pp.print_unique_identifiers := true;
    match a, b with
      | RVar x, RVar y when x = y ->
          []
      | RVar x, other when Ident.unifiable x ->
          RSubst.add table x other
      | other, RVar x when Ident.unifiable x ->
          RSubst.add table x other
      | RSub (r1, x1), RSub (r2, x2) when x1 = x2 ->
          unify sub r1 r2
      | RSub _, RSub _ ->
          fail ()
      (* Neither side can be bound: keep the pair as a constraint, with the
         variable first when only one side is a variable. *)
      | RVar _, RVar _
      | RVar _, RSub _ ->
          [ a, b ]
      | RSub _, RVar _ ->
          [ b, a ]

  (* Raise [Occur_check_error] if [i] occurs in region [r]. *)
  let rec occur_check i r =
    match r with
      | RVar j -> if i = j then occur_check_error ()
      | RSub (inner, _) -> occur_check i inner

  (* Replace identifier [i] by region [v] everywhere in [r]. *)
  let rec subst i v r =
    match r with
      | RVar j -> if i = j then v else r
      | RSub (inner, x) -> RSub (subst i v inner, x)
end

module rec TSubst: SUBST with type v = Bast.type_expr = Subst(Type)

(* Unification of type expressions.  Type variables are bound in the
   module-level substitution [subst]; regions occurring in types are unified
   with [Region.unify], and the resulting constraints are returned. *)
and Type: UNIFY with type t = Bast.type_expr = struct
  type t = Bast.type_expr

  (* Substitution of type identifiers, shared by every call to [unify] and
     updated in place by [TSubst.add]. *)
  let subst = TSubst.create ()

  (* Apply the current type substitution to [v]. *)
  let expand v =
    TSubst.apply subst v

  (* [unify sub a b] unifies [a] and [b], binding unifiable type identifiers
     in [subst], and returns the region constraints produced along the way.
     [sub] is only forwarded to recursive calls and to [Region.unify].
     Raises [Unification_error] on a structural mismatch, including an
     arity mismatch between tuple, logic or pointer arguments. *)
  let rec unify sub a b =
    let unification_error () =
      unification_error
        (fun fmt () -> Pp.type_expr fmt a)
        (fun fmt () -> Pp.type_expr fmt b)
    in
    (* [List.map2] raises [Invalid_argument] on lists of different lengths,
       and only after unifying (and possibly binding) their common prefix.
       Check arities up front so a mismatch is reported as a unification
       error before any binding is made. *)
    let check_same_length la lb =
      if List.length la <> List.length lb then unification_error ()
    in
    Pp.print_unique_identifiers := true;
    let unify = unify sub in
    let unify_regions = Region.unify sub in
    match a.node, b.node with
      | TEIdent a, TEIdent b when a = b ->
          []
      | TEIdent a, _ when Ident.unifiable a ->
          TSubst.add subst a b
      | _, TEIdent b when Ident.unifiable b ->
          TSubst.add subst b a
      | TETuple a, TETuple b ->
          check_same_length a b;
          List.flatten (List.map2 unify a b)
      | TELogic (tla, ida), TELogic (tlb, idb) ->
          if ida <> idb then unification_error ();
          check_same_length tla tlb;
          List.flatten (List.map2 unify tla tlb)
      | TEBase a, TEBase b when a = b ->
          []
      | TEPointer ((rla, tla, ca), ra), TEPointer ((rlb, tlb, cb), rb) ->
          if ca <> cb then unification_error ();
          check_same_length rla rlb;
          check_same_length tla tlb;
          let cr = List.map2 unify_regions rla rlb in
          let ct = List.map2 unify tla tlb in
          let c = unify_regions ra rb in
          List.flatten (c :: cr @ ct)
      | TESum (a1, b1), TESum (a2, b2) ->
          unify a1 a2 @ unify b1 b2
      | _ -> unification_error ()

  (* Raise [Occur_check_error] if type identifier [i] occurs in [t]
     (regions are not inspected). *)
  let rec occur_check i t =
    match t.node with
      | TEIdent j when i = j ->
          occur_check_error ()
      | TEIdent _ | TEBase _ ->
          ()
      | TETuple l
      | TEPointer ((_, l, _), _)
      | TELogic (l, _) ->
          List.iter (occur_check i) l
      | TESum (a, b) ->
          occur_check i a;
          occur_check i b

  (* Replace type identifier [i] by [v] everywhere in [t]; regions are left
     unchanged. *)
  let rec subst i v t =
    match t.node with
      | TEIdent j when i = j ->
          v
      | TEIdent _ | TEBase _ ->
          t
      | TETuple l ->
          { t with node = TETuple (List.map (subst i v) l) }
      | TESum (a, b) ->
          { t with node = TESum (subst i v a, subst i v b) }
      | TEPointer ((rl, tl, c), r) ->
          { t with node = TEPointer ((rl, List.map (subst i v) tl, c), r) }
      | TELogic (l, id) ->
          { t with node = TELogic (List.map (subst i v) l, id) }
end

(*
(* strict *)
let rec is_sub_region env a b =
  match a, b with
    | RVar a, _ when Env.get_region_parent env a = Some b ->
        true
    | RSub (a1, a2), RSub (b1, b2) when a2 = b2 ->
        is_sub_region env a1 b1
    | _ ->
        false
*)

(* Apply the current region substitution to region [r]. *)
let expand_region r = Region.expand r

(* Register [expand_region] with the pretty-printer through
   [Pp.expand_region_ref] (presumably so regions are expanded when printed;
   TODO confirm in Pp). *)
let () =
  Pp.expand_region_ref := expand_region

(* Apply the current region substitution to every region occurring in [t]:
   the region arguments and trailing region of each [TEPointer], recursing
   through tuples, sums, logic types and pointer type arguments.  Type
   identifiers and base types are returned unchanged. *)
let rec expand_regions_in_type t =
  let expand_all = List.map expand_regions_in_type in
  let with_node node = { t with node = node } in
  match t.node with
    | TEIdent _ | TEBase _ ->
        t
    | TETuple types ->
        with_node (TETuple (expand_all types))
    | TESum (left, right) ->
        with_node
          (TESum (expand_regions_in_type left, expand_regions_in_type right))
    | TELogic (types, id) ->
        with_node (TELogic (expand_all types, id))
    | TEPointer ((regions, types, c), r) ->
        let regions = List.map expand_region regions in
        let types = expand_all types in
        with_node (TEPointer ((regions, types, c), expand_region r))

(* Apply the current type substitution to [t], then the current region
   substitution to every region of the result. *)
let expand_type t =
  let with_types_expanded = Type.expand t in
  expand_regions_in_type with_types_expanded

(* Apply the current region substitution to every region of a permission,
   keeping its constructor. *)
let expand_permission p =
  let expand = expand_region in
  match p with
    | PEmpty r -> PEmpty (expand r)
    | POpen r -> POpen (expand r)
    | PClosed r -> PClosed (expand r)
    | PGroup r -> PGroup (expand r)
    | PArrow (s, r) -> PArrow (expand s, expand r)
    | PSub (s, r) -> PSub (expand s, expand r)

(* Unify regions [a] and [b] with [Region.unify] and return the remaining
   constraints.  A [Unification_error] is reported through
   [Loc.locate_error loc], showing both expanded regions and the mismatching
   sub-terms.  Other exceptions (e.g. [Occur_check_error]) propagate. *)
let unify_regions loc sub a b =
  try Region.unify sub a b with
    | Unification_error (left, right) ->
        Pp.print_unique_identifiers := true;
        Loc.locate_error loc "@.@[<hv>@[<hv 2>unification error: regions@ %a and@ %a@]@ @[<hv 2>details:@ %a and@ %a@]@]"
          Pp.region (expand_region a)
          Pp.region (expand_region b)
          left () right ()

(* Unify types [a] and [b] with [Type.unify] and return the region
   constraints produced.  A [Unification_error] is reported through
   [Loc.locate_error loc], showing both expanded types and the mismatching
   sub-terms.  Other exceptions (e.g. [Occur_check_error]) propagate. *)
let unify_types loc sub a b =
  try Type.unify sub a b with
    | Unification_error (left, right) ->
        Pp.print_unique_identifiers := true;
        Loc.locate_error loc "@[<hv>@[<hv 2>unification error: types@ %a and@ %a@]@ @[<hv 2>details:@ %a and@ %a@]@]"
          Pp.type_expr (expand_type a)
          Pp.type_expr (expand_type b)
          left () right ()
