(**************************************************************************)
(* Copyright (c) 2010, Romain BARDOU                                      *)
(* All rights reserved.                                                   *)
(*                                                                        *)
(* Redistribution and  use in  source and binary  forms, with  or without *)
(* modification, are permitted provided that the following conditions are *)
(* met:                                                                   *)
(*                                                                        *)
(* * Redistributions  of  source code  must  retain  the above  copyright *)
(*   notice, this list of conditions and the following disclaimer.        *)
(* * Redistributions in  binary form  must reproduce the  above copyright *)
(*   notice, this list of conditions  and the following disclaimer in the *)
(*   documentation and/or other materials provided with the distribution. *)
(* * Neither the  name of Capucine nor  the names of its contributors may *)
(*   be used  to endorse or  promote products derived  from this software *)
(*   without specific prior written permission.                           *)
(*                                                                        *)
(* THIS SOFTWARE  IS PROVIDED BY  THE COPYRIGHT HOLDERS  AND CONTRIBUTORS *)
(* "AS  IS" AND  ANY EXPRESS  OR IMPLIED  WARRANTIES, INCLUDING,  BUT NOT *)
(* LIMITED TO, THE IMPLIED  WARRANTIES OF MERCHANTABILITY AND FITNESS FOR *)
(* A PARTICULAR PURPOSE  ARE DISCLAIMED. IN NO EVENT  SHALL THE COPYRIGHT *)
(* OWNER OR CONTRIBUTORS BE  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, *)
(* SPECIAL,  EXEMPLARY,  OR  CONSEQUENTIAL  DAMAGES (INCLUDING,  BUT  NOT *)
(* LIMITED TO, PROCUREMENT OF SUBSTITUTE  GOODS OR SERVICES; LOSS OF USE, *)
(* DATA, OR PROFITS; OR BUSINESS  INTERRUPTION) HOWEVER CAUSED AND ON ANY *)
(* THEORY OF  LIABILITY, WHETHER IN  CONTRACT, STRICT LIABILITY,  OR TORT *)
(* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING  IN ANY WAY OUT OF THE USE *)
(* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.   *)
(**************************************************************************)

open Misc
open Format

exception Type_error of Lang_ast.location * string

(* [type_error loc s args...] formats the message described by the format
   [s] and its arguments with [Format] into a fresh buffer, then raises
   [Type_error (loc, message)].  It never returns normally, so a call can
   be used wherever a value of any type is expected. *)
let type_error loc s =
  let buf = Buffer.create 128 in
  let raise_with_message ppf =
    (* Flush pending pretty-printing material into [buf] before reading
       its contents. *)
    pp_print_flush ppf ();
    raise (Type_error (loc, Buffer.contents buf))
  in
  kfprintf raise_with_message (formatter_of_buffer buf) s

(* Built-in base types; [pp_type] prints them as "int", "bool" and
   "unit" respectively. *)
type base_type =
  | TInt
  | TBool
  | TUnit

(* Type descriptions, parameterized by the representation ['r] of regions
   and ['t] of nested types.  They are instantiated with the union-find
   nodes [URegion.t] and [UType.t] in the recursive [Type] module. *)
type ('r, 't) typ =
  (* One of the built-in base types. *)
  | TBase of base_type
  (* Pointer type tagged with a region; printed as "[r]". *)
  | TPointer of 'r
  (* Pointer type whose region is not fixed; printed as "[?]".  Unification
     with a [TPointer] keeps the [TPointer]. *)
  | TLogicPointer
  (* Unknown type (unification variable); unifies with any type and is
     printed as "?uid". *)
  | TVar
  (* Named type parameter, printed as "'name"; only unifies with a
     [TPoly] of the same name. *)
  | TPoly of Ident.t
  (* Logic type [name] applied to a list of region arguments and a list of
     regular (type) arguments; both arities must match for unification. *)
  | TLogic of Ident.t * 'r list * 't list

(* Region descriptions used as the data of [URegion] nodes. *)
type region =
  (* Named region, printed as its identifier. *)
  | RRoot of Ident.t
  (* [ROwn (var, id)], printed as "var.id".  NOTE(review): presumably the
     region [id] owned by variable [var] - confirm. *)
  | ROwn of Ident.t * Ident.t
  (* Unknown region (unification variable); unifies with any region and is
     printed as "?uid". *)
  | RVar

(* [unify_identifiers a b] succeeds silently when [a] and [b] are the same
   identifier according to [Ident.compare], and raises [Type_error] (with a
   dummy location) naming both identifiers otherwise. *)
let unify_identifiers a b =
  match Ident.compare a b with
    | 0 -> ()
    | _ ->
        type_error Lang.dummy_location "unification error (%a and %a)"
          Ident.pp a Ident.pp b

(* Unification of types and regions.  [Type] and [Region] describe how two
   resolved descriptions are merged; [UType] and [URegion] are built from
   them with [Unify.Make].  The four modules are mutually recursive because
   type descriptions contain region nodes ([TPointer], [TLogic]) and nested
   type nodes ([TLogic]). *)
module rec Type:
  Unify.UNIFIABLE with type t = (URegion.t, UType.t) typ =
struct
  type t = (URegion.t, UType.t) typ

  (* [check_arity kind name l1 l2] raises [Type_error] if the two argument
     lists supplied to the logic type [name] differ in length.  [kind]
     names the argument category in the error message. *)
  let check_arity kind name l1 l2 =
    let n1 = List.length l1 and n2 = List.length l2 in
    if n1 <> n2 then
      type_error Lang.dummy_location
        "wrong %s argument count for %a (%d / %d)"
        kind Ident.pp name n1 n2

  (* Merge two type descriptions and return the description to keep.
     - A [TVar] adopts the other side.
     - A [TLogicPointer] is refined by a [TPointer].
     - Pointers unify their regions.
     - Logic types must have the same name and arities, and their
       arguments are unified pairwise.
     Any other mismatch raises [Type_error]. *)
  let unify a b =
    match a, b with
      | TVar, x | x, TVar ->
          x
      | TBase TInt, TBase TInt
      | TBase TBool, TBase TBool
      | TBase TUnit, TBase TUnit ->
          a
      | TPoly id_a, TPoly id_b ->
          unify_identifiers id_a id_b;
          a
      | TPointer ra, TPointer rb ->
          URegion.unify ra rb;
          a
      | TLogicPointer, TLogicPointer ->
          a
      | TPointer _, TLogicPointer ->
          a
      | TLogicPointer, TPointer _ ->
          b
      | TLogic (n1, r1, a1), TLogic (n2, r2, a2) when Ident.compare n1 n2 = 0 ->
          (* Check both arities before unifying anything, so that
             [List.iter2] below cannot raise [Invalid_argument]. *)
          check_arity "region" n1 r1 r2;
          check_arity "regular" n1 a1 a2;
          List.iter2 URegion.unify r1 r2;
          List.iter2 UType.unify a1 a2;
          a
      | _ ->
          type_error Lang.dummy_location "unification error (types)"
end

and UType: Unify.UNIFY with type data = Type.t = Unify.Make(Type)

and Region:
  Unify.UNIFIABLE with type t = region =
struct
  type t = region

  (* Merge two region descriptions and return the description to keep.
     - An [RVar] adopts the other side.
     - Root regions must have the same identifier.
     - Owned regions must match component-wise.
     Any other combination raises [Type_error]. *)
  let unify a b =
    match a, b with
      | RVar, x | x, RVar ->
          x
      | RRoot id_a, RRoot id_b ->
          unify_identifiers id_a id_b;
          a
      | ROwn (var_a, id_a), ROwn (var_b, id_b) ->
          unify_identifiers var_a var_b;
          unify_identifiers id_a id_b;
          a
      | _ ->
          type_error Lang.dummy_location "unification error (regions)"
end

and URegion: Unify.UNIFY with type data = Region.t = Unify.Make(Region)

(* Print a region after resolving it with [URegion.find]:
   - an unknown region as "?uid";
   - a root region as its identifier;
   - an [ROwn (owner, id)] region as "owner.id". *)
let pp_region fmt r =
  match URegion.find r with
    | RVar -> fprintf fmt "?%d" (URegion.uid r)
    | RRoot id -> fprintf fmt "%a" Ident.pp id
    | ROwn (owner, id) -> fprintf fmt "%a.%a" Ident.pp owner Ident.pp id

(* Print a type after resolving it with [UType.find]:
   - unknown types as "?uid";
   - type parameters as "'name";
   - base types by name;
   - pointers as "[region]", or "[?]" when the region is not fixed;
   - logic types as "name [r1, r2] (t1, t2)" inside a hov box, where
     either bracketed group is omitted when its list is empty. *)
let rec pp_type fmt t =
  match UType.find t with
    | TVar -> fprintf fmt "?%d" (UType.uid t)
    | TPoly id -> fprintf fmt "'%a" Ident.pp id
    | TBase base ->
        let name =
          match base with
            | TInt -> "int"
            | TBool -> "bool"
            | TUnit -> "unit"
        in
        pp_print_string fmt name
    | TPointer r -> fprintf fmt "[%a]" pp_region r
    | TLogicPointer -> fprintf fmt "[?]"
    | TLogic (id, regions, args) ->
        (* For a non-empty list, emit a break hint, then [opening], then
           the elements separated by ",@ " in a hov box, then [closing].
           For an empty list, print nothing. *)
        let pp_group opening closing pp_elt = function
          | [] -> ()
          | first :: rest ->
              fprintf fmt "@ %s@[<hov 2>%a" opening pp_elt first;
              List.iter (fun elt -> fprintf fmt ",@ %a" pp_elt elt) rest;
              fprintf fmt "@]%s" closing
        in
        fprintf fmt "@[<hov 2>%a" Ident.pp id;
        pp_group "[" "]" pp_region regions;
        pp_group "(" ")" pp_type args;
        fprintf fmt "@]"

(* Union-find nodes for the three base types, created once when this module
   is initialized.  NOTE(review): if [UType.create] assigns uids from a
   counter, the creation order below is observable - keep it stable. *)
let tunit = UType.create (TBase TUnit)
let tint = UType.create (TBase TInt)
let tbool = UType.create (TBase TBool)
