libqbe/tools/match.ml
2024-04-09 21:32:49 +02:00

331 lines
7.5 KiB
OCaml

type cls = Kw | Kl | Ks | Kd
type op_base =
| Oadd
| Osub
| Omul
type op = cls * op_base
let commutative = function
| (_, (Oadd | Omul)) -> true
| (_, _) -> false
let associative = function
| (_, (Oadd | Omul)) -> true
| (_, _) -> false
type atomic_pattern =
| Tmp
| AnyCon
| Con of int64
type pattern =
| Bnr of op * pattern * pattern
| Atm of atomic_pattern
| Var of string * atomic_pattern
let rec pattern_match p w =
match p with
| Var _ ->
failwith "variable not allowed"
| Atm (Tmp) ->
begin match w with
| Atm (Con _ | AnyCon) -> false
| _ -> true
end
| Atm (Con _) -> w = p
| Atm (AnyCon) ->
not (pattern_match (Atm Tmp) w)
| Bnr (o, pl, pr) ->
begin match w with
| Bnr (o', wl, wr) ->
o' = o &&
pattern_match pl wl &&
pattern_match pr wr
| _ -> false
end
type 'a cursor = (* a position inside a pattern *)
| Bnrl of op * 'a cursor * pattern
| Bnrr of op * pattern * 'a cursor
| Top of 'a
let rec fold_cursor c p =
match c with
| Bnrl (o, c', p') -> fold_cursor c' (Bnr (o, p, p'))
| Bnrr (o, p', c') -> fold_cursor c' (Bnr (o, p', p))
| Top _ -> p
let peel p x =
let once out (p, c) =
match p with
| Var _ -> failwith "variable not allowed"
| Atm _ -> (p, c) :: out
| Bnr (o, pl, pr) ->
(pl, Bnrl (o, c, pr)) ::
(pr, Bnrr (o, pl, c)) :: out
in
let rec go l =
let l' = List.fold_left once [] l in
if List.length l' = List.length l
then l
else go l'
in go [(p, Top x)]
let fold_pairs l1 l2 ini f =
let rec go acc = function
| [] -> acc
| a :: l1' ->
go (List.fold_left
(fun acc b -> f (a, b) acc)
acc l2) l1'
in go ini l1
let iter_pairs l f =
fold_pairs l l () (fun x () -> f x)
type 'a state =
{ id: int
; seen: pattern
; point: ('a cursor) list }
let rec binops side {point; _} =
List.fold_left (fun res c ->
match c, side with
| Bnrl (o, c, r), `L -> ((o, c), r) :: res
| Bnrr (o, l, c), `R -> ((o, c), l) :: res
| _ -> res)
[] point
let group_by_fst l =
List.fast_sort (fun (a, _) (b, _) ->
compare a b) l |>
List.fold_left (fun (oo, l, res) (o', c) ->
match oo with
| None -> (Some o', [c], [])
| Some o when o = o' -> (oo, c :: l, res)
| Some o -> (Some o', [c], (o, l) :: res))
(None, [], []) |>
(function
| (None, _, _) -> []
| (Some o, l, res) -> (o, l) :: res)
let sort_uniq cmp l =
List.fast_sort cmp l |>
List.fold_left (fun (eo, l) e' ->
match eo with
| None -> (Some e', l)
| Some e ->
if cmp e e' = 0
then (eo, l)
else (Some e', e :: l)
) (None, []) |>
(function
| (None, _) -> []
| (Some e, l) -> List.rev (e :: l))
let normalize (point: ('a cursor) list) =
sort_uniq compare point
let nextbnr tmp s1 s2 =
let pm w (_, p) = pattern_match p w in
let o1 = binops `L s1 |>
List.filter (pm s2.seen) |>
List.map fst
and o2 = binops `R s2 |>
List.filter (pm s1.seen) |>
List.map fst
in
List.map (fun (o, l) ->
o,
{ id = 0
; seen = Bnr (o, s1.seen, s2.seen)
; point = normalize (l @ tmp)
}) (group_by_fst (o1 @ o2))
type p = string
module StateSet : sig
type set
val create: unit -> set
val add: set -> p state ->
[> `Added | `Found ] * p state
val iter: set -> (p state -> unit) -> unit
val elems: set -> (p state) list
end = struct
include Hashtbl.Make(struct
type t = p state
let equal s1 s2 = s1.point = s2.point
let hash s = Hashtbl.hash s.point
end)
type set =
{ h: int t
; mutable next_id: int }
let create () =
{ h = create 500; next_id = 1 }
let add set s =
(* delete the check later *)
assert (s.point = normalize s.point);
try
let id = find set.h s in
`Found, {s with id}
with Not_found -> begin
let id = set.next_id in
set.next_id <- id + 1;
add set.h s id;
`Added, {s with id}
end
let iter set f =
let f s id = f {s with id} in
iter f set.h
let elems set =
let res = ref [] in
iter set (fun s -> res := s :: !res);
!res
end
type table_key =
| K of op * p state * p state
module StateMap = Map.Make(struct
type t = table_key
let compare ka kb =
match ka, kb with
| K (o, sl, sr), K (o', sl', sr') ->
compare (o, sl.id, sr.id)
(o', sl'.id, sr'.id)
end)
type rule =
{ name: string
; pattern: pattern
(* TODO access pattern *)
}
let generate_table rl =
let states = StateSet.create () in
(* initialize states *)
let ground =
List.fold_left
(fun ini r ->
peel r.pattern r.name @ ini)
[] rl |>
group_by_fst
in
let find x d l =
try List.assoc x l with Not_found -> d in
let tmp = find (Atm Tmp) [] ground in
let con = find (Atm AnyCon) [] ground in
let () =
List.iter (fun (seen, l) ->
let point =
if pattern_match (Atm Tmp) seen
then normalize (tmp @ l)
else normalize (con @ l)
in
let s = {id = 0; seen; point} in
let flag, _ = StateSet.add states s in
assert (flag = `Added)
) ground
in
(* setup loop state *)
let map = ref StateMap.empty in
let map_add k s' =
map := StateMap.add k s' !map
in
let flag = ref `Added in
let flagmerge = function
| `Added -> flag := `Added
| _ -> ()
in
(* iterate until fixpoint *)
while !flag = `Added do
flag := `Stop;
let statel = StateSet.elems states in
iter_pairs statel (fun (sl, sr) ->
nextbnr tmp sl sr |>
List.iter (fun (o, s') ->
let flag', s' =
StateSet.add states s' in
flagmerge flag';
map_add (K (o, sl, sr)) s';
));
done;
(StateSet.elems states, !map)
let intersperse x l =
let rec go left right out =
let out =
(List.rev left @ [x] @ right) ::
out in
match right with
| x :: right' ->
go (x :: left) right' out
| [] -> out
in go [] l []
let rec permute = function
| [] -> [[]]
| x :: l ->
List.concat (List.map
(intersperse x) (permute l))
(* build all binary trees with ordered
* leaves l *)
let rec bins build l =
let rec go l r out =
match r with
| [] -> out
| x :: r' ->
go (l @ [x]) r'
(fold_pairs
(bins build l)
(bins build r)
out (fun (l, r) out ->
build l r :: out))
in
match l with
| [] -> []
| [x] -> [x]
| x :: l -> go [x] l []
let products l ini f =
let rec go acc la = function
| [] -> f (List.rev la) acc
| xs :: l ->
List.fold_left (fun acc x ->
go acc (x :: la) l)
acc xs
in go ini [] l
(* combinatorial nuke... *)
let rec ac_equiv =
let rec alevel o = function
| Bnr (o', l, r) when o' = o ->
alevel o l @ alevel o r
| x -> [x]
in function
| Bnr (o, _, _) as p
when associative o ->
products
(List.map ac_equiv (alevel o p)) []
(fun choice out ->
List.map
(bins (fun l r -> Bnr (o, l, r)))
(if commutative o
then permute choice
else [choice]) |>
List.concat |>
(fun l -> List.rev_append l out))
| Bnr (o, l, r)
when commutative o ->
fold_pairs
(ac_equiv l) (ac_equiv r) []
(fun (l, r) out ->
Bnr (o, l, r) ::
Bnr (o, r, l) :: out)
| Bnr (o, l, r) ->
fold_pairs
(ac_equiv l) (ac_equiv r) []
(fun (l, r) out ->
Bnr (o, l, r) :: out)
| x -> [x]