From a374da3c2e205bb8c8548c1a63f186b6d9188e9c Mon Sep 17 00:00:00 2001 From: Quentin Carbonneaux Date: Thu, 14 Dec 2017 22:35:30 +0100 Subject: [PATCH] modulo ac matching and more tests --- tools/match.ml | 353 ++++++++++++++++++++++++++++++++++++-------- tools/match_test.ml | 102 +++++++++++++ 2 files changed, 390 insertions(+), 65 deletions(-) create mode 100644 tools/match_test.ml diff --git a/tools/match.ml b/tools/match.ml index 0eaa244..4aeeae0 100644 --- a/tools/match.ml +++ b/tools/match.ml @@ -5,26 +5,36 @@ type op_base = | Omul type op = cls * op_base +let commutative = function + | (_, (Oadd | Omul)) -> true + | (_, _) -> false + +let associative = function + | (_, (Oadd | Omul)) -> true + | (_, _) -> false + type atomic_pattern = - | Any + | Tmp + | AnyCon | Con of int64 type pattern = | Bnr of op * pattern * pattern - | Unr of op * pattern | Atm of atomic_pattern + | Var of string * atomic_pattern let rec pattern_match p w = match p with - | Atm (Any) -> true - | Atm (Con _) -> w = p - | Unr (o, pa) -> + | Var _ -> + failwith "variable not allowed" + | Atm (Tmp) -> begin match w with - | Unr (o', wa) -> - o' = o && - pattern_match pa wa - | _ -> false + | Atm (Con _ | AnyCon) -> false + | _ -> true end + | Atm (Con _) -> w = p + | Atm (AnyCon) -> + not (pattern_match (Atm Tmp) w) | Bnr (o, pl, pr) -> begin match w with | Bnr (o', wl, wr) -> @@ -34,75 +44,288 @@ let rec pattern_match p w = | _ -> false end -let test_pattern_match = - let pm = pattern_match - and nm = fun x y -> not (pattern_match x y) - and o = (Kw, Oadd) in - begin - assert (pm (Atm Any) (Atm (Con 42L))); - assert (pm (Atm Any) (Unr (o, Atm Any))); - assert (nm (Atm (Con 42L)) (Atm Any)); - assert (pm (Unr (o, Atm Any)) - (Unr (o, Atm (Con 42L)))); - assert (nm (Unr (o, Atm Any)) - (Unr ((Kl, Oadd), Atm (Con 42L)))); - assert (nm (Unr (o, Atm Any)) - (Bnr (o, Atm (Con 42L), Atm Any))); - end - -type cursor = (* a position inside a pattern *) - | Bnrl of op * cursor * pattern - | Bnrr of op * pattern * cursor - | Unra of op * cursor - | Top +type 'a cursor = (* a position inside a pattern *) + | Bnrl of op * 'a cursor * pattern + | Bnrr of op * pattern * 'a cursor + | Top of 'a let rec fold_cursor c p = match c with | Bnrl (o, c', p') -> fold_cursor c' (Bnr (o, p, p')) | Bnrr (o, p', c') -> fold_cursor c' (Bnr (o, p', p)) - | Unra (o, c') -> fold_cursor c' (Unr (o, p)) - | Top -> p + | Top _ -> p -let peel p = - let once out (c, p) = +let peel p x = + let once out (p, c) = match p with - | Atm _ -> (c, p) :: out - | Unr (o, pa) -> - (Unra (o, c), pa) :: out + | Var _ -> failwith "variable not allowed" + | Atm _ -> (p, c) :: out | Bnr (o, pl, pr) -> - (Bnrl (o, c, pr), pl) :: - (Bnrr (o, pl, c), pr) :: out + (pl, Bnrl (o, c, pr)) :: + (pr, Bnrr (o, pl, c)) :: out in let rec go l = let l' = List.fold_left once [] l in if List.length l' = List.length l then l else go l' - in go [(Top, p)] + in go [(p, Top x)] -let test_peel = - let o = Kw, Oadd in - let p = Bnr (o, Bnr (o, Atm Any, Atm Any), - Atm (Con 42L)) in - let l = peel p in - let () = assert (List.length l = 3) in - let atomic_p (_, p) = - match p with Atm _ -> true | _ -> false in - let () = assert (List.for_all atomic_p l) in - let l = List.map (fun (c, p) -> fold_cursor c p) l in - let () = assert (List.for_all ((=) p) l) in - () +let fold_pairs l1 l2 ini f = + let rec go acc = function + | [] -> acc + | a :: l1' -> + go (List.fold_left + (fun acc b -> f (a, b) acc) + acc l2) l1' + in go ini l1 -(* we want to compute all the configurations we could - * possibly be in when processing a block of instructions; - * to do so, we start with all the possible cursors for - * the list of patterns we are given, this will be our - * main "initial state"; each constant (used in the - * patterns) also generates a state of its own - * - * to create new states we can take pairs of states, and - * combine them with binary operations, we keep the - * result if it is non-trivial (non-empty) and new (we - * have not seen this cursor combination yet); we can - * also do the same with unary operations - * *) +let iter_pairs l f = + fold_pairs l l () (fun x () -> f x) + +type 'a state = + { id: int + ; seen: pattern + ; point: ('a cursor) list } + +let rec binops side {point; _} = + List.fold_left (fun res c -> + match c, side with + | Bnrl (o, c, r), `L -> ((o, c), r) :: res + | Bnrr (o, l, c), `R -> ((o, c), l) :: res + | _ -> res) + [] point + +let group_by_fst l = + List.fast_sort (fun (a, _) (b, _) -> + compare a b) l |> + List.fold_left (fun (oo, l, res) (o', c) -> + match oo with + | None -> (Some o', [c], []) + | Some o when o = o' -> (oo, c :: l, res) + | Some o -> (Some o', [c], (o, l) :: res)) + (None, [], []) |> + (function + | (None, _, _) -> [] + | (Some o, l, res) -> (o, l) :: res) + +let sort_uniq cmp l = + List.fast_sort cmp l |> + List.fold_left (fun (eo, l) e' -> + match eo with + | None -> (Some e', l) + | Some e -> + if cmp e e' = 0 + then (eo, l) + else (Some e', e :: l) + ) (None, []) |> + (function + | (None, _) -> [] + | (Some e, l) -> List.rev (e :: l)) + +let normalize (point: ('a cursor) list) = + sort_uniq compare point + +let nextbnr tmp s1 s2 = + let pm w (_, p) = pattern_match p w in + let o1 = binops `L s1 |> + List.filter (pm s2.seen) |> + List.map fst + and o2 = binops `R s2 |> + List.filter (pm s1.seen) |> + List.map fst + in + List.map (fun (o, l) -> + o, + { id = 0 + ; seen = Bnr (o, s1.seen, s2.seen) + ; point = normalize (l @ tmp) + }) (group_by_fst (o1 @ o2)) + +type p = string + +module StateSet : sig + type set + val create: unit -> set + val add: set -> p state -> + [> `Added | `Found ] * p state + val iter: set -> (p state -> unit) -> unit + val elems: set -> (p state) list +end = struct + include Hashtbl.Make(struct + type t = p state + let equal s1 s2 = s1.point = s2.point + let hash s = Hashtbl.hash s.point + end) + type set = + { h: int t + ; mutable next_id: int } + let create () = + { h = create 500; next_id = 1 } + let add set s = + (* delete the check later *) + assert (s.point = normalize s.point); + try + let id = find set.h s in + `Found, {s with id} + with Not_found -> begin + let id = set.next_id in + set.next_id <- id + 1; + add set.h s id; + `Added, {s with id} + end + let iter set f = + let f s id = f {s with id} in + iter f set.h + let elems set = + let res = ref [] in + iter set (fun s -> res := s :: !res); + !res +end + +type table_key = + | K of op * p state * p state + +module StateMap = Map.Make(struct + type t = table_key + let compare ka kb = + match ka, kb with + | K (o, sl, sr), K (o', sl', sr') -> + compare (o, sl.id, sr.id) + (o', sl'.id, sr'.id) +end) + +type rule = + { name: string + ; pattern: pattern + (* TODO access pattern *) + } + +let generate_table rl = + let states = StateSet.create () in + (* initialize states *) + let ground = + List.fold_left + (fun ini r -> + peel r.pattern r.name @ ini) + [] rl |> + group_by_fst + in + let find x d l = + try List.assoc x l with Not_found -> d in + let tmp = find (Atm Tmp) [] ground in + let con = find (Atm AnyCon) [] ground in + let () = + List.iter (fun (seen, l) -> + let point = + if pattern_match (Atm Tmp) seen + then normalize (tmp @ l) + else normalize (con @ l) + in + let s = {id = 0; seen; point} in + let flag, _ = StateSet.add states s in + assert (flag = `Added) + ) ground + in + (* setup loop state *) + let map = ref StateMap.empty in + let map_add k s' = + map := StateMap.add k s' !map + in + let flag = ref `Added in + let flagmerge = function + | `Added -> flag := `Added + | _ -> () + in + (* iterate until fixpoint *) + while !flag = `Added do + flag := `Stop; + let statel = StateSet.elems states in + iter_pairs statel (fun (sl, sr) -> + nextbnr tmp sl sr |> + List.iter (fun (o, s') -> + let flag', s' = + StateSet.add states s' in + flagmerge flag'; + map_add (K (o, sl, sr)) s'; + )); + done; + (StateSet.elems states, !map) + +let intersperse x l = + let rec go left right out = + let out = + (List.rev left @ [x] @ right) :: + out in + match right with + | x :: right' -> + go (x :: left) right' out + | [] -> out + in go [] l [] + +let rec permute = function + | [] -> [[]] + | x :: l -> + List.concat (List.map + (intersperse x) (permute l)) + +(* build all binary trees with ordered + * leaves l *) +let rec bins build l = + let rec go l r out = + match r with + | [] -> out + | x :: r' -> + go (l @ [x]) r' + (fold_pairs + (bins build l) + (bins build r) + out (fun (l, r) out -> + build l r :: out)) + in + match l with + | [] -> [] + | [x] -> [x] + | x :: l -> go [x] l [] + +let products l ini f = + let rec go acc la = function + | [] -> f (List.rev la) acc + | xs :: l -> + List.fold_left (fun acc x -> + go acc (x :: la) l) + acc xs + in go ini [] l + +(* combinatorial nuke... *) +let rec ac_equiv = + let rec alevel o = function + | Bnr (o', l, r) when o' = o -> + alevel o l @ alevel o r + | x -> [x] + in function + | Bnr (o, _, _) as p + when associative o -> + products + (List.map ac_equiv (alevel o p)) [] + (fun choice out -> + List.map + (bins (fun l r -> Bnr (o, l, r))) + (if commutative o + then permute choice + else [choice]) |> + List.concat |> + (fun l -> List.rev_append l out)) + | Bnr (o, l, r) + when commutative o -> + fold_pairs + (ac_equiv l) (ac_equiv r) [] + (fun (l, r) out -> + Bnr (o, l, r) :: + Bnr (o, r, l) :: out) + | Bnr (o, l, r) -> + fold_pairs + (ac_equiv l) (ac_equiv r) [] + (fun (l, r) out -> + Bnr (o, l, r) :: out) + | x -> [x] diff --git a/tools/match_test.ml b/tools/match_test.ml new file mode 100644 index 0000000..75e2005 --- /dev/null +++ b/tools/match_test.ml @@ -0,0 +1,102 @@ +#use "match.ml" + +let test_pattern_match = + let pm = pattern_match + and nm = fun x y -> not (pattern_match x y) in + begin + assert (nm (Atm Tmp) (Atm (Con 42L))); + assert (pm (Atm AnyCon) (Atm (Con 42L))); + assert (nm (Atm (Con 42L)) (Atm AnyCon)); + assert (nm (Atm (Con 42L)) (Atm Tmp)); + end + +let test_peel = + let o = Kw, Oadd in + let p = Bnr (o, Bnr (o, Atm Tmp, Atm Tmp), + Atm (Con 42L)) in + let l = peel p () in + let () = assert (List.length l = 3) in + let atomic_p (p, _) = + match p with Atm _ -> true | _ -> false in + let () = assert (List.for_all atomic_p l) in + let l = List.map (fun (p, c) -> fold_cursor c p) l in + let () = assert (List.for_all ((=) p) l) in + () + +let test_fold_pairs = + let l = [1; 2; 3; 4; 5] in + let p = fold_pairs l l [] (fun a b -> a :: b) in + let () = assert (List.length p = 25) in + let p = sort_uniq compare p in + let () = assert (List.length p = 25) in + () + +(* test pattern & state *) +let tp = + let o = Kw, Oadd in + Bnr (o, Bnr (o, Atm Tmp, Atm Tmp), + Atm (Con 0L)) +let ts = + { id = 0 + ; seen = Atm Tmp + ; point = + List.map snd + (List.filter (fun (p, _) -> p = Atm Tmp) + (peel tp ())) + } + +let print_sm = + let op_str (k, o) = + Printf.sprintf "%s%s" + (match o with + | Oadd -> "add" + | Osub -> "sub" + | Omul -> "mul") + (match k with + | Kw -> "w" + | Kl -> "l" + | Ks -> "s" + | Kd -> "d") + in + StateMap.iter (fun k s' -> + match k with + | K (o, sl, sr) -> + Printf.printf + "(%s %d %d) -> %d\n" + (op_str o) + sl.id sr.id s'.id + ) + +let address_rules = + let oa = Kl, Oadd in + let om = Kl, Omul in + let rule name pattern = { name; pattern; } in + (* o + b *) + [ rule "ob1" (Bnr (oa, Atm Tmp, Atm AnyCon)) + ; rule "ob2" (Bnr (oa, Atm AnyCon, Atm Tmp)) + + (* b + s * i *) + ; rule "bs1" (Bnr (oa, Atm Tmp, Bnr (om, Atm AnyCon, Atm Tmp))) + ; rule "bs2" (Bnr (oa, Atm Tmp, Bnr (om, Atm Tmp, Atm AnyCon))) + ; rule "bs3" (Bnr (oa, Bnr (om, Atm AnyCon, Atm Tmp), Atm Tmp)) + ; rule "bs4" (Bnr (oa, Bnr (om, Atm Tmp, Atm AnyCon), Atm Tmp)) + + (* o + s * i *) + ; rule "os1" (Bnr (oa, Atm AnyCon, Bnr (om, Atm AnyCon, Atm Tmp))) + ; rule "os2" (Bnr (oa, Atm AnyCon, Bnr (om, Atm Tmp, Atm AnyCon))) + ; rule "os3" (Bnr (oa, Bnr (om, Atm AnyCon, Atm Tmp), Atm AnyCon)) + ; rule "os4" (Bnr (oa, Bnr (om, Atm Tmp, Atm AnyCon), Atm AnyCon)) + ] + +(* +let sl, sm = generate_table address_rules +let s n = List.find (fun {id; _} -> id = n) sl +let () = print_sm sm +*) + +let tp0 = + let o = Kw, Oadd in + Bnr (o, Atm Tmp, Atm (Con 0L)) +let tp1 = + let o = Kw, Oadd in + Bnr (o, tp0, Atm (Con 1L))