(* Classes of values; [show_op] prints them as the suffixes "w",
 * "l", "s" and "d" (presumably QBE's word, long, single and double
 * classes -- confirm). *)
type cls = Kw | Kl | Ks | Kd
(* Binary operations that may appear in rule patterns. *)
type op_base =
  | Oadd
  | Osub
  | Omul
  | Oor
  | Oshl
  | Oshr
(* An operation tagged with the class it works on. *)
type op = cls * op_base

(* Every operation base, in declaration order. *)
let op_bases =
  [ Oadd
  ; Osub
  ; Omul
  ; Oor
  ; Oshl
  ; Oshr ]

(* [commutative (k, o)] tells whether [o] is treated as commutative;
 * the class [k] plays no role.  Every base is listed so that adding
 * a new operation forces a decision here. *)
let commutative (_, base) =
  match base with
  | Oadd | Omul | Oor -> true
  | Osub | Oshl | Oshr -> false

(* [associative (k, o)] tells whether [o] is treated as associative;
 * the class [k] plays no role.  Every base is listed so that adding
 * a new operation forces a decision here. *)
let associative (_, base) =
  match base with
  | Oadd | Omul | Oor -> true
  | Osub | Oshl | Oshr -> false

(* Leaves of a pattern. *)
type atomic_pattern =
  | Tmp           (* any operand that is not a constant *)
  | AnyCon        (* any constant *)
  | Con of int64  (* one specific constant *)
(* Tmp < AnyCon < Con k; this is also the order in which the
 * polymorphic [compare] sorts them (constant constructors first). *)

(* Rule patterns. *)
type pattern =
  | Bnr of op * pattern * pattern     (* binary node: op, left, right *)
  | Atm of atomic_pattern             (* anonymous leaf *)
  | Var of string * atomic_pattern    (* leaf bound to a variable name *)

(* [is_atomic p] is true when [p] is a leaf (named or not). *)
let is_atomic p =
  match p with
  | Bnr _ -> false
  | Atm _ | Var _ -> true

(* Mnemonic of an operation base, without class suffix. *)
let show_op_base = function
  | Oadd -> "add"
  | Osub -> "sub"
  | Omul -> "mul"
  | Oor -> "or"
  | Oshl -> "shl"
  | Oshr -> "shr"

(* Mnemonic of an operation followed by its one-letter class
 * suffix, e.g. "addw". *)
let show_op (k, o) =
  let suffix =
    match k with
    | Kw -> "w"
    | Kl -> "l"
    | Ks -> "s"
    | Kd -> "d"
  in
  show_op_base o ^ suffix

(* Render a pattern: "%" for a temporary, "$" for any constant, the
 * decimal value of a literal constant, a "'name" suffix for a
 * variable leaf, and "(op left right)" for a binary node. *)
let rec show_pattern = function
  | Atm Tmp -> "%"
  | Atm AnyCon -> "$"
  | Atm (Con n) -> Int64.to_string n
  | Var (v, a) -> Printf.sprintf "%s'%s" (show_pattern (Atm a)) v
  | Bnr (o, pl, pr) ->
      Printf.sprintf "(%s %s %s)"
        (show_op o) (show_pattern pl) (show_pattern pr)

(* The atomic pattern of a leaf (looking through a variable), or
 * [None] for a binary node. *)
let get_atomic = function
  | Bnr _ -> None
  | Atm a | Var (_, a) -> Some a

(* [pattern_match p w] tells whether pattern [p] accepts [w], where
 * [w] describes a term (it may itself be a pattern):
 *  - a [Tmp] leaf accepts anything that is not a constant, binary
 *    nodes included;
 *  - an [AnyCon] leaf accepts any constant ([Con _] or [AnyCon]);
 *  - a [Con k] leaf accepts exactly the constant [k];
 *  - a binary node accepts a binary node with the same operation
 *    whose arms are accepted pairwise.
 * Variables are transparent on both sides: a leaf of [p] is compared
 * with the atomic part of a [Var] in [w].  (Before this change the
 * [Con] case compared [w] structurally and so rejected [Var (_, Con
 * k)], unlike the [Tmp]/[AnyCon] cases.) *)
let rec pattern_match p w =
  match p with
  | Var (_, a) ->
      pattern_match (Atm a) w
  | Atm Tmp ->
      begin match get_atomic w with
      | Some (Con _ | AnyCon) -> false
      | Some Tmp | None -> true
      end
  | Atm (Con k) ->
      get_atomic w = Some (Con k)
  | Atm AnyCon ->
      not (pattern_match (Atm Tmp) w)
  | Bnr (o, pl, pr) ->
      begin match w with
      | Bnr (o', wl, wr) ->
          o' = o &&
          pattern_match pl wl &&
          pattern_match pr wr
      | Atm _ | Var _ -> false
      end

(* A position inside a pattern, zipper-style (see [fold_cursor]):
 * [Bnrl (o, c, r)] is the left arm of an [o] node whose right arm is
 * [r] and whose own position is [c]; [Bnrr (o, l, c)] is the right
 * arm of an [o] node whose left arm is [l]; [Top x] is the root,
 * tagged with [x] ([generate_table] uses the rule name). *)
type +'a cursor = (* a position inside a pattern *)
  | Bnrl of op * 'a cursor * pattern
  | Bnrr of op * pattern * 'a cursor
  | Top of 'a

(* [fold_cursor c p] plugs [p] into the hole described by [c] and
 * rebuilds the enclosing pattern all the way up to the root. *)
let rec fold_cursor c p =
  match c with
  | Top _ -> p
  | Bnrl (o, up, right) -> fold_cursor up (Bnr (o, p, right))
  | Bnrr (o, left, up) -> fold_cursor up (Bnr (o, left, p))

(* [peel p x] decomposes [p] into its leaves.  The result pairs every
 * leaf of [p] (variables stripped down to their atomic pattern) with
 * its cursor, the root being [Top x].  Each pass splits every binary
 * node into its two arms; passes repeat until the list stops
 * growing. *)
let peel p x =
  let expand acc (q, cur) =
    match q with
    | Atm _ -> (q, cur) :: acc
    | Var (_, a) -> (Atm a, cur) :: acc
    | Bnr (o, left, right) ->
        (left, Bnrl (o, cur, right)) ::
        (right, Bnrr (o, left, cur)) :: acc
  in
  let rec fixpoint items =
    let before = List.length items in
    let items' = List.fold_left expand [] items in
    if List.length items' <> before
    then fixpoint items'
    else items'
  in
  fixpoint [(p, Top x)]

(* [fold_pairs l1 l2 ini f] folds [f] over the cartesian product of
 * [l1] and [l2], starting from [ini]: the outer loop walks [l1] and
 * the inner loop walks [l2], both front to back. *)
let fold_pairs l1 l2 ini f =
  List.fold_left
    (fun acc a ->
      List.fold_left (fun acc b -> f (a, b) acc) acc l2)
    ini l1

(* [iter_pairs l f] applies [f] to every ordered pair [(a, b)] of
 * elements of [l] (including [a == b]), [a] varying slowest. *)
let iter_pairs l f =
  List.iter (fun a -> List.iter (fun b -> f (a, b)) l) l

(* Swap the two components of every pair of [l], keeping order. *)
let inverse l =
  let swap (a, b) = (b, a) in
  List.map swap l

(* An automaton state.
 *  - [id]: number assigned by [StateSet.add]; -1 before insertion.
 *  - [seen]: a representative term that reaches this state.
 *  - [point]: the normalized list of pattern positions (cursors)
 *    such a term may occupy; [StateSet] identifies states by it. *)
type 'a state =
  { id: int
  ; seen: pattern
  ; point: ('a cursor) list }

(* [binops side s] selects the cursors of [s.point] that sit in the
 * [side] arm ([`L] or [`R]) of a binary node.  For each one it returns
 * [((op, parent cursor), sibling pattern)], where the sibling is the
 * other arm of that node.
 * (Not recursive: the former [let rec] was unused.) *)
let binops side {point; _} =
  List.filter_map (fun cur ->
      match cur, side with
      | Bnrl (o, up, right), `L -> Some ((o, up), right)
      | Bnrr (o, left, up), `R -> Some ((o, up), left)
      | _ -> None)
    point

(* [group_by_fst l] gathers the values of the association list [l]
 * by key.  Keys come out in decreasing order.  Inside a group, the
 * values appear in the reverse of their order in [l] (the sort is
 * stable).  Keys are grouped with structural [=]. *)
let group_by_fst l =
  let sorted = List.fast_sort (fun (a, _) (b, _) -> compare a b) l in
  List.fold_left (fun groups (k, v) ->
      match groups with
      | (k', vs) :: rest when k' = k -> (k', v :: vs) :: rest
      | _ -> (k, [v]) :: groups)
    [] sorted

(* [sort_uniq cmp l] sorts [l] with [cmp] and keeps only the first
 * element of each run of [cmp]-equal elements. *)
let sort_uniq cmp l =
  let rec dedup kept = function
    | [] -> List.rev kept
    | x :: rest ->
        begin match kept with
        | last :: _ when cmp last x = 0 -> dedup kept rest
        | _ -> dedup (x :: kept) rest
        end
  in
  dedup [] (List.fast_sort cmp l)

(* Sorted list of the distinct elements of [l] (polymorphic compare). *)
let setify l =
  l |> sort_uniq compare

(* Canonical form of a cursor list: sorted and duplicate-free, so
 * that two equal sets of cursors are structurally equal. *)
let normalize (point : 'a cursor list) : 'a cursor list =
  point |> setify

(* [next_binary tmp s1 s2] lists, per operation, the state reached by
 * a binary node whose left arm is in [s1] and right arm in [s2].
 * An operation [o] qualifies in two ways:
 *  - through a cursor of [s1] in the left arm of an [o] node whose
 *    right sibling pattern accepts [s2.seen];
 *  - through a cursor of [s2] in the right arm of an [o] node whose
 *    left sibling accepts [s1.seen].
 * Each new state has id -1 and seen [Bnr (o, s1.seen, s2.seen)].
 * Its point is the normalized union of the parent cursors and [tmp]. *)
let next_binary tmp s1 s2 =
  let parents side s other =
    binops side s
    |> List.filter (fun (_, sibling) -> pattern_match sibling other.seen)
    |> List.map fst
  in
  let from_left = parents `L s1 s2 in
  let from_right = parents `R s2 s1 in
  group_by_fst (from_left @ from_right)
  |> List.map (fun (o, ups) ->
         let st =
           { id = -1
           ; seen = Bnr (o, s1.seen, s2.seen)
           ; point = normalize (ups @ tmp) }
         in
         (o, st))

(* Tag carried by [Top] cursors; [generate_table] uses rule names. *)
type p = string

(* Set of automaton states that numbers states as they are inserted.
 * Two states are the same when their [point]s are equal; [id] and
 * [seen] are ignored for identity. *)
module StateSet : sig
  type t
  val create: unit -> t
  val add: t -> p state ->
           [> `Added | `Found ] * p state
  val iter: t -> (p state -> unit) -> unit
  val elems: t -> (p state) list
end = struct
  (* Named instance instead of a structure-level [open], so that the
   * hash-table operations no longer shadow this module's own
   * [add]/[iter]/[create]/[t]. *)
  module H = Hashtbl.Make(struct
    type t = p state
    let equal s1 s2 = s1.point = s2.point
    let hash s = Hashtbl.hash s.point
  end)

  type t =
    { h: int H.t          (* state (keyed by point) -> its id *)
    ; mutable next_id: int }

  let create () =
    { h = H.create 500; next_id = 0 }

  (* [add set s] returns [`Found] and the known id when a state with
   * the same point is present; otherwise it numbers [s] with the
   * next fresh id and returns [`Added].  [s.point] must already be
   * normalized. *)
  let add set s =
    assert (s.point = normalize s.point);
    match H.find_opt set.h s with
    | Some id -> `Found, {s with id}
    | None ->
        let id = set.next_id in
        set.next_id <- id + 1;
        H.add set.h s id;
        `Added, {s with id}

  (* Calls [f] on every stored state, with its id filled in. *)
  let iter set f =
    H.iter (fun s id -> f {s with id}) set.h

  (* All stored states, with ids, in reverse iteration order. *)
  let elems set =
    let res = ref [] in
    iter set (fun s -> res := s :: !res);
    !res
end

(* Key of the transition table: an operation applied to a left-arm
 * state and a right-arm state.  [StateMap] compares keys by the
 * operation and the two state ids only. *)
type table_key =
  | K of op * p state * p state

(* Transition table keyed by [table_key], plus two regrouping views. *)
module StateMap = struct
  include Map.Make(struct
      type t = table_key
      (* only the operation and the two state ids take part *)
      let compare (K (o, sl, sr)) (K (o', sl', sr')) =
        compare (o, sl.id, sr.id) (o', sl'.id, sr'.id)
    end)

  (* [invert n sm] gives, for each of the [n] target state ids, the
   * ways to reach it grouped by operation:
   * [(op, [(left id, right id); ...]); ...]. *)
  let invert n sm =
    let rmap = Array.make n [] in
    fold (fun (K (o, sl, sr)) {id; _} () ->
        rmap.(id) <- (o, (sl.id, sr.id)) :: rmap.(id))
      sm ();
    Array.map group_by_fst rmap

  (* [by_ops sm] regroups all transitions by operation:
   * [(op, [((left id, right id), target id); ...]); ...]. *)
  let by_ops sm =
    let transitions =
      fold (fun (K (op, l, r)) s acc ->
          (op, ((l.id, r.id), s.id)) :: acc)
        sm []
    in
    group_by_fst transitions
end

(* A named rule.  [name] tags the [Top] cursors of [pattern] and
 * selects the patterns used by [lr_matcher].  NOTE(review): [vars] is
 * not read in this chunk; presumably the variables bound by
 * [pattern] -- confirm. *)
type rule =
  { name: string
  ; vars: string list
  ; pattern: pattern }

(* Build the matching automaton for the rules [rl].
 *
 * Returns [(states, atoms, map)]:
 *  - [states]: every state discovered, stored at index [id] (ids are
 *    assigned by [StateSet] in discovery order, without gaps);
 *  - [atoms]: for each leaf pattern occurring in the rules (including
 *    the implicit "$" and "%" rules added below), its initial state;
 *  - [map]: transitions; [K (o, sl, sr)] is bound to the state of
 *    an [o] node whose arms are in states [sl] and [sr].
 *
 * States are found by a fixpoint.  Each pass combines every ordered
 * pair of known states with [next_binary], and passes repeat until
 * one adds no new state. *)
let generate_table rl =
  let states = StateSet.create () in
  let rl =
    (* these atomic patterns must occur in
     * rules so that we are able to number
     * all possible refs *)
    [ { name = "$"; vars = []
      ; pattern = Atm AnyCon }
    ; { name = "%"; vars = []
      ; pattern = Atm Tmp } ] @ rl
  in
  (* initialize states *)
  (* [ground] maps each leaf pattern (variables stripped) to the
   * cursors of all its occurrences in the rules *)
  let ground =
    List.concat_map
      (fun r -> peel r.pattern r.name) rl |>
    group_by_fst
  in
  (* cursors of the [Tmp] leaves and of the [AnyCon] leaves; both
   * exist thanks to the two rules prepended above *)
  let tmp = List.assoc (Atm Tmp) ground in
  let con = List.assoc (Atm AnyCon) ground in
  let atoms = ref [] in
  let () =
    List.iter (fun (seen, l) ->
      (* a leaf also occupies every position where any temporary
       * (resp. any constant) is accepted *)
      let point =
        if pattern_match (Atm Tmp) seen
        then normalize (tmp @ l)
        else normalize (con @ l)
      in
      let s = {id = -1; seen; point} in
      let _, s = StateSet.add states s in
      match get_atomic seen with
      | Some atm -> atoms := (atm, s) :: !atoms
      | None -> ()
    ) ground
  in
  (* setup loop state *)
  let map = ref StateMap.empty in
  let map_add k s' =
    map := StateMap.add k s' !map
  in
  (* [flag] becomes [`Added] whenever a pass inserts a new state *)
  let flag = ref `Added in
  let flagmerge = function
    | `Added -> flag := `Added
    | _ -> ()
  in
  (* iterate until fixpoint *)
  while !flag = `Added do
    flag := `Stop;
    (* snapshot: states added during this pass are combined in the
     * next one *)
    let statel = StateSet.elems states in
    iter_pairs statel (fun (sl, sr) ->
      next_binary tmp sl sr |>
      List.iter (fun (o, s') ->
        let flag', s' =
          StateSet.add states s' in
        flagmerge flag';
        map_add (K (o, sl, sr)) s';
    ));
  done;
  let states =
    StateSet.elems states |>
    List.sort (fun s s' -> compare s.id s'.id) |>
    Array.of_list
  in
  (states, !atoms, !map)

(* [intersperse x l] lists every way of inserting [x] into [l]: for
 * [l] of length n, n+1 lists, from [x] placed last down to [x]
 * placed first. *)
let intersperse x l =
  (* [prefix] holds the elements already passed, reversed *)
  let rec loop prefix suffix acc =
    let acc = List.rev_append prefix (x :: suffix) :: acc in
    match suffix with
    | [] -> acc
    | y :: rest -> loop (y :: prefix) rest acc
  in
  loop [] l []

(* All permutations of a list (duplicates kept when [l] has equal
 * elements). *)
let rec permute = function
  | [] -> [[]]
  | x :: rest ->
      List.concat_map (intersperse x) (permute rest)

(* [bins build l] builds every binary tree whose leaves, read left to
 * right, are [l]; [build] joins two subtrees.  Returns [] for an
 * empty [l]. *)
let rec bins build l =
  (* [left] / [right] are the leaves on each side of the current
   * split point; the split moves right one leaf at a time *)
  let rec split left right acc =
    match right with
    | [] -> acc
    | y :: right' ->
        let acc =
          fold_pairs (bins build left) (bins build right) acc
            (fun (tl, tr) acc -> build tl tr :: acc)
        in
        split (left @ [y]) right' acc
  in
  match l with
  | [] -> []
  | [leaf] -> [leaf]
  | first :: rest -> split [first] rest []

(* [products l ini f] folds [f] over every choice of one element from
 * each list of [l] (the cartesian product), starting from [ini].
 * Each choice is passed in the order of [l]. *)
let products l ini f =
  let rec walk chosen acc = function
    | [] -> f (List.rev chosen) acc
    | options :: more ->
        List.fold_left
          (fun acc x -> walk (x :: chosen) acc more)
          acc options
  in
  walk [] ini l

(* combinatorial nuke... *)
(* [ac_equiv p] enumerates patterns equivalent to [p] up to the
 * associativity and commutativity of its operations.
 *  - Associative root [o]: the maximal chain of nested [o] nodes is
 *    flattened into its operands.  Each operand is expanded
 *    recursively, and for every combination of expansions, every
 *    binary tree over every ordering of the operands is built
 *    (all permutations if [o] is commutative, only the original
 *    order otherwise).
 *  - Commutative, non-associative root: both arm orders are emitted.
 *  - Otherwise the arms are expanded pairwise.
 * Output size grows combinatorially and duplicates are not removed. *)
let rec ac_equiv =
  (* operands of the maximal [o]-chain rooted at the argument, left
   * to right *)
  let rec alevel o = function
    | Bnr (o', l, r) when o' = o ->
        alevel o l @ alevel o r
    | x -> [x]
  in function
  | Bnr (o, _, _) as p
    when associative o ->
      products
        (List.map ac_equiv (alevel o p)) []
        (fun choice out ->
          List.concat_map
            (bins (fun l r -> Bnr (o, l, r)))
            (if commutative o
              then permute choice
              else [choice]) @ out)
  | Bnr (o, l, r)
    when commutative o ->
      fold_pairs
        (ac_equiv l) (ac_equiv r) []
        (fun (l, r) out ->
          Bnr (o, l, r) ::
          Bnr (o, r, l) :: out)
  | Bnr (o, l, r) ->
      fold_pairs
        (ac_equiv l) (ac_equiv r) []
        (fun (l, r) out ->
          Bnr (o, l, r) :: out)
  | x -> [x]

(* Hash-consed matcher actions built by [lr_matcher].  Structurally
 * equal actions share one [id], so [equal] is an integer comparison.
 * Node meanings as used in this file (runtime interpretation happens
 * elsewhere -- NOTE(review): confirm):
 *  - [Switch]: continue according to the current state id;
 *  - [Push]: descend into the left arm (flagged when the state is
 *    symmetric);
 *  - [Pop]: move on to the right arm;
 *  - [Set v]: record the current node as variable [v];
 *  - [Stop]: done. *)
module Action: sig
  type node =
    | Switch of (int * t) list
    | Push of bool * t
    | Pop of t
    | Set of string * t
    | Stop
  and t = private
    { id: int; node: node }
  val equal: t -> t -> bool
  val size: t -> int
  val stop: t
  val mk_push: sym:bool -> t -> t
  val mk_pop: t -> t
  val mk_set: string -> t -> t
  val mk_switch: int list -> (int -> t) -> t
  val pp: Format.formatter -> t -> unit
end = struct
  type node =
    | Switch of (int * t) list
    | Push of bool * t
    | Pop of t
    | Set of string * t
    | Stop
  and t =
    { id: int; node: node }

  let equal a a' = a.id = a'.id

  (* Number of distinct nodes reachable from [a]; a shared sub-action
   * is counted once. *)
  let size a =
    let seen = Hashtbl.create 10 in
    let rec node_size = function
      | Switch l ->
          List.fold_left
            (fun n (_, a) -> n + size a) 0 l
      | (Push (_, a) | Pop a | Set (_, a)) ->
          size a
      | Stop -> 0
    and size {id; node} =
      if Hashtbl.mem seen id
      then 0
      else begin
        Hashtbl.add seen id ();
        1 + node_size node
      end
    in
    size a

  (* Shallow identity of a node: children are replaced by their ids.
   * Every [t] is built by [mk] ([t] is private), so two children are
   * structurally equal exactly when their ids are equal.  Keying the
   * hash-consing table on this form therefore yields the same ids as
   * keying it on the whole node.  It avoids hashing and structurally
   * comparing entire DAGs, which re-walks shared sub-actions. *)
  type key =
    | KSwitch of (int * int) list
    | KPush of bool * int
    | KPop of int
    | KSet of string * int
    | KStop

  let key_of_node = function
    | Switch l -> KSwitch (List.map (fun (c, a) -> (c, a.id)) l)
    | Push (sym, a) -> KPush (sym, a.id)
    | Pop a -> KPop a.id
    | Set (v, a) -> KSet (v, a.id)
    | Stop -> KStop

  (* The only constructor of [t]: reuses the id of an identical node,
   * or allocates the next fresh one. *)
  let mk =
    let hcons = Hashtbl.create 100 in
    let fresh = ref 0 in
    fun node ->
      let key = key_of_node node in
      let id =
        match Hashtbl.find_opt hcons key with
        | Some id -> id
        | None ->
            let id = !fresh in
            Hashtbl.add hcons key id;
            fresh := id + 1;
            id
      in
      {id; node}
  let stop = mk Stop
  let mk_push ~sym a = mk (Push (sym, a))
  (* popping right before stopping is dropped *)
  let mk_pop a =
    match a.node with
    | Stop -> a
    | _ -> mk (Pop a)
  let mk_set v a = mk (Set (v, a))
  (* [mk_switch ids f] dispatches on [ids]; collapses to the common
   * case when all cases are equal.
   * @raise Failure "empty switch" when [ids] is empty. *)
  let mk_switch ids f =
    match List.map f ids with
    | [] -> failwith "empty switch"
    | c :: cs as cases ->
        if List.for_all (equal c) cs then c
        else
          let cases = List.combine ids cases in
          mk (Switch cases)

  open Format
  (* Switch cases sharing the same action are printed together as
   * "→id1,id2:" *)
  let rec pp_node fmt = function
    | Switch l ->
        fprintf fmt "@[<v>@[<v2>switch{";
        let pp_case (c, a) =
          let pp_sep fmt () = fprintf fmt "," in
          fprintf fmt "@,@[<2>→%a:@ @[%a@]@]"
            (pp_print_list ~pp_sep pp_print_int)
            c pp a
        in
        inverse l |> group_by_fst |> inverse |>
          List.iter pp_case;
        fprintf fmt "@]@,}@]"
    | Push (true, a) -> fprintf fmt "pushsym@ %a" pp a
    | Push (false, a) -> fprintf fmt "push@ %a" pp a
    | Pop a -> fprintf fmt "pop@ %a" pp a
    | Set (v, a) -> fprintf fmt "set(%s)@ %a" v pp a
    | Stop -> fprintf fmt "•"
  and pp fmt a = pp_node fmt a.node
end

(* [symmetric rmap id] tells whether state [id] is commutative: for
 * every operation, (a op b) enters it iff (b op a) does.  [rmap.(id)]
 * lists, per operation, the (left id, right id) pairs leading to [id]
 * (as built by [StateMap.invert]); pairs with equal ids are ignored. *)
let symmetric rmap id =
  let balanced (_, pairs) =
    let off_diagonal = List.filter (fun (a, b) -> a <> b) pairs in
    let below, above =
      List.partition (fun (a, b) -> a < b) off_diagonal
    in
    setify below = setify (inverse above)
  in
  List.for_all balanced rmap.(id)

(* left-to-right matching of a set of patterns;
 * may raise if there is no lr matcher for the
 * input rule
 *
 * [lr_matcher statemap states rules name] builds an [Action.t]
 * meant for a term whose root state is one of the states whose point
 * contains [Top name].  The action switches on state ids, visits the
 * left arm of a binary node before its right arm, and issues [Set v]
 * when the matched leaf is the variable [v].  Only the [rules] called
 * [name] are considered.
 *
 * Failures visible in this block:
 *  - [Failure "ambiguous var match"] when one leaf would bind two
 *    different variables;
 *  - [Failure "empty switch"] from [Action.mk_switch] when a set of
 *    candidate ids is empty;
 *  - the local exception [Stuck] can escape when no pattern fits a
 *    reachable state (NOTE(review): callers can only catch it with a
 *    wildcard handler). *)
let lr_matcher statemap states rules name =
  (* rmap.(id): per operation, the (left id, right id) pairs that
   * lead to state [id] *)
  let rmap =
    let nstates = Array.length states in
    StateMap.invert nstates statemap
  in
  let exception Stuck in
  (* the list of ids represents a class of terms
   * whose root ends up being labelled with one
   * such id; the gen function generates a matcher
   * that will, given any such term, assign values
   * for the Var nodes of one pattern in pats *)
  (* [k] is the continuation: it receives the id of the current node
   * and the patterns still alive once that node is matched.  The ['a]
   * payload of each pattern carries what is needed to rebuild the
   * enclosing pattern (hence the polymorphic annotation). *)
  let rec gen
  : 'a. int list -> (pattern * 'a) list
        -> (int -> (pattern * 'a) list -> Action.t)
        -> Action.t
  = fun ids pats k ->
    Action.mk_switch (setify ids) @@ fun id_top ->
    (* when the pairs entering id_top are closed under swapping, only
     * ordered pairs (a <= b) are explored and the push is flagged *)
    let sym = symmetric rmap id_top in
    let id_ops =
      if sym then
        let ordered (a, b) = a <= b in
        List.map (fun (o, l) ->
            (o, List.filter ordered l))
          rmap.(id_top)
      else rmap.(id_top)
    in
    (* consider only the patterns that are
     * compatible with the current id *)
    let atm_pats, bin_pats =
      List.filter (function
          | Bnr (o, _, _), _ ->
              List.exists
                (fun (o', _) -> o' = o)
                id_ops
          | _ -> true) pats |>
      List.partition
        (fun (pat, _) -> is_atomic pat)
    in
    (* binary case first; a [Stuck] raised here or by any nested
     * [gen]/[k] call (they run eagerly while building the action)
     * falls back to the atomic case in the handler *)
    try
      if bin_pats = [] then raise Stuck;
      (* pats_l: left arms, carrying (op, payload, right arm) *)
      let pats_l =
        List.map (function
            | (Bnr (o, l, r), x) ->
                (l, (o, x, r))
            | _ -> assert false)
          bin_pats
      (* pats_r: surviving left arms -> right arms, carrying
       * (op, left arm, payload) *)
      and pats_r =
        List.map (fun (l, (o, x, r)) ->
            (r, (o, l, x)))
      (* patstop: rebuild the parent binary patterns *)
      and patstop =
        List.map (fun (r, (o, l, x)) ->
            (Bnr (o, l, r), x))
      in
      (* every (left id, right id) pair that can produce id_top *)
      let id_pairs = List.concat_map snd id_ops in
      let ids_l = List.map fst id_pairs
      (* right ids compatible with a given left id *)
      and ids_r id_left =
        List.filter_map (fun (l, r) ->
            if l = id_left then Some r else None)
          id_pairs
      in
      (* match the left arm *)
      Action.mk_push ~sym
        (gen ids_l pats_l
         @@ fun lid pats ->
         (* then the right arm, considering
          * only the remaining possible
          * patterns and knowing that the
          * left arm was numbered 'lid' *)
          Action.mk_pop
            (gen (ids_r lid) (pats_r pats)
             @@ fun _rid pats ->
             (* continue with the parent *)
             k id_top (patstop pats)))
    with Stuck ->
      (* atomic case: keep the leaf patterns that accept the
       * representative term of id_top *)
      let atm_pats =
        let seen = states.(id_top).seen in
        List.filter (fun (pat, _) ->
            pattern_match pat seen) atm_pats
      in
      if atm_pats = [] then raise Stuck else
      let vars =
        List.filter_map (function
            | (Var (v, _), _) -> Some v
            | _ -> None) atm_pats |> setify
      in
      match vars with
      | [] -> k id_top atm_pats
      | [v] -> Action.mk_set v (k id_top atm_pats)
      | _ -> failwith "ambiguous var match"
  in
  (* generate a matcher for the rule *)
  (* states in which a term can be a complete match of rule [name] *)
  let ids_top =
    Array.to_list states |>
    List.filter_map (fun {id; point = p; _} ->
        if List.exists ((=) (Top name)) p then
          Some id
        else None)
  in
  (* drop a pattern when it matches one of the patterns after it *)
  let rec filter_dups pats =
    match pats with
    | p :: pats ->
        if List.exists (pattern_match p) pats
        then filter_dups pats
        else p :: filter_dups pats
    | [] -> []
  in
  let pats_top =
    List.filter_map (fun r ->
        if r.name = name then
          Some r.pattern
        else None) rules |>
    filter_dups |>
    List.map (fun p -> (p, ()))
  in
  gen ids_top pats_top (fun _ pats ->
      assert (pats <> []);
      Action.stop)

(* Tables used to number terms; presumably filled from the results of
 * [generate_table] (see [make_numberer]) -- confirm at call sites. *)
type numberer =
  { atoms: (atomic_pattern * p state) list
    (* initial state of each atomic leaf pattern *)
  ; statemap: p state StateMap.t
    (* transitions: (op, left state, right state) -> state *)
  ; states: p state array
    (* all states, indexed by id *)
  ; mutable ops: op list
    (* memoizes the list of possible operations
     * according to the statemap *) }

(* [make_numberer sa am sm] packs the state array [sa], the atom
 * table [am] and the transition map [sm]; the [ops] memo starts
 * empty. *)
let make_numberer sa am sm =
  { states = sa
  ; atoms = am
  ; statemap = sm
  ; ops = [] }

(* Initial state of the atomic pattern [atm].
 * @raise Not_found when [atm] has no entry in [n.atoms]. *)
let atom_state n atm =
  match List.assoc_opt atm n.atoms with
  | Some s -> s
  | None -> raise Not_found