modulo ac matching and more tests
This commit is contained in:
parent
24d1324424
commit
a374da3c2e
2 changed files with 390 additions and 65 deletions
353
tools/match.ml
353
tools/match.ml
|
@ -5,26 +5,36 @@ type op_base =
|
|||
| Omul
|
||||
type op = cls * op_base
|
||||
|
||||
let commutative = function
|
||||
| (_, (Oadd | Omul)) -> true
|
||||
| (_, _) -> false
|
||||
|
||||
let associative = function
|
||||
| (_, (Oadd | Omul)) -> true
|
||||
| (_, _) -> false
|
||||
|
||||
type atomic_pattern =
|
||||
| Any
|
||||
| Tmp
|
||||
| AnyCon
|
||||
| Con of int64
|
||||
|
||||
type pattern =
|
||||
| Bnr of op * pattern * pattern
|
||||
| Unr of op * pattern
|
||||
| Atm of atomic_pattern
|
||||
| Var of string * atomic_pattern
|
||||
|
||||
let rec pattern_match p w =
|
||||
match p with
|
||||
| Atm (Any) -> true
|
||||
| Atm (Con _) -> w = p
|
||||
| Unr (o, pa) ->
|
||||
| Var _ ->
|
||||
failwith "variable not allowed"
|
||||
| Atm (Tmp) ->
|
||||
begin match w with
|
||||
| Unr (o', wa) ->
|
||||
o' = o &&
|
||||
pattern_match pa wa
|
||||
| _ -> false
|
||||
| Atm (Con _ | AnyCon) -> false
|
||||
| _ -> true
|
||||
end
|
||||
| Atm (Con _) -> w = p
|
||||
| Atm (AnyCon) ->
|
||||
not (pattern_match (Atm Tmp) w)
|
||||
| Bnr (o, pl, pr) ->
|
||||
begin match w with
|
||||
| Bnr (o', wl, wr) ->
|
||||
|
@ -34,75 +44,288 @@ let rec pattern_match p w =
|
|||
| _ -> false
|
||||
end
|
||||
|
||||
let test_pattern_match =
|
||||
let pm = pattern_match
|
||||
and nm = fun x y -> not (pattern_match x y)
|
||||
and o = (Kw, Oadd) in
|
||||
begin
|
||||
assert (pm (Atm Any) (Atm (Con 42L)));
|
||||
assert (pm (Atm Any) (Unr (o, Atm Any)));
|
||||
assert (nm (Atm (Con 42L)) (Atm Any));
|
||||
assert (pm (Unr (o, Atm Any))
|
||||
(Unr (o, Atm (Con 42L))));
|
||||
assert (nm (Unr (o, Atm Any))
|
||||
(Unr ((Kl, Oadd), Atm (Con 42L))));
|
||||
assert (nm (Unr (o, Atm Any))
|
||||
(Bnr (o, Atm (Con 42L), Atm Any)));
|
||||
end
|
||||
|
||||
type cursor = (* a position inside a pattern *)
|
||||
| Bnrl of op * cursor * pattern
|
||||
| Bnrr of op * pattern * cursor
|
||||
| Unra of op * cursor
|
||||
| Top
|
||||
type 'a cursor = (* a position inside a pattern *)
|
||||
| Bnrl of op * 'a cursor * pattern
|
||||
| Bnrr of op * pattern * 'a cursor
|
||||
| Top of 'a
|
||||
|
||||
let rec fold_cursor c p =
|
||||
match c with
|
||||
| Bnrl (o, c', p') -> fold_cursor c' (Bnr (o, p, p'))
|
||||
| Bnrr (o, p', c') -> fold_cursor c' (Bnr (o, p', p))
|
||||
| Unra (o, c') -> fold_cursor c' (Unr (o, p))
|
||||
| Top -> p
|
||||
| Top _ -> p
|
||||
|
||||
let peel p =
|
||||
let once out (c, p) =
|
||||
let peel p x =
|
||||
let once out (p, c) =
|
||||
match p with
|
||||
| Atm _ -> (c, p) :: out
|
||||
| Unr (o, pa) ->
|
||||
(Unra (o, c), pa) :: out
|
||||
| Var _ -> failwith "variable not allowed"
|
||||
| Atm _ -> (p, c) :: out
|
||||
| Bnr (o, pl, pr) ->
|
||||
(Bnrl (o, c, pr), pl) ::
|
||||
(Bnrr (o, pl, c), pr) :: out
|
||||
(pl, Bnrl (o, c, pr)) ::
|
||||
(pr, Bnrr (o, pl, c)) :: out
|
||||
in
|
||||
let rec go l =
|
||||
let l' = List.fold_left once [] l in
|
||||
if List.length l' = List.length l
|
||||
then l
|
||||
else go l'
|
||||
in go [(Top, p)]
|
||||
in go [(p, Top x)]
|
||||
|
||||
let test_peel =
|
||||
let o = Kw, Oadd in
|
||||
let p = Bnr (o, Bnr (o, Atm Any, Atm Any),
|
||||
Atm (Con 42L)) in
|
||||
let l = peel p in
|
||||
let () = assert (List.length l = 3) in
|
||||
let atomic_p (_, p) =
|
||||
match p with Atm _ -> true | _ -> false in
|
||||
let () = assert (List.for_all atomic_p l) in
|
||||
let l = List.map (fun (c, p) -> fold_cursor c p) l in
|
||||
let () = assert (List.for_all ((=) p) l) in
|
||||
()
|
||||
let fold_pairs l1 l2 ini f =
|
||||
let rec go acc = function
|
||||
| [] -> acc
|
||||
| a :: l1' ->
|
||||
go (List.fold_left
|
||||
(fun acc b -> f (a, b) acc)
|
||||
acc l2) l1'
|
||||
in go ini l1
|
||||
|
||||
(* we want to compute all the configurations we could
|
||||
* possibly be in when processing a block of instructions;
|
||||
* to do so, we start with all the possible cursors for
|
||||
* the list of patterns we are given, this will be our
|
||||
* main "initial state"; each constant (used in the
|
||||
* patterns) also generates a state of its own
|
||||
*
|
||||
* to create new states we can take pairs of states, and
|
||||
* combine them with binary operations, we keep the
|
||||
* result if it is non-trivial (non-empty) and new (we
|
||||
* have not seen this cursor combination yet); we can
|
||||
* also do the same with unary operations
|
||||
* *)
|
||||
let iter_pairs l f =
|
||||
fold_pairs l l () (fun x () -> f x)
|
||||
|
||||
type 'a state =
|
||||
{ id: int
|
||||
; seen: pattern
|
||||
; point: ('a cursor) list }
|
||||
|
||||
let rec binops side {point; _} =
|
||||
List.fold_left (fun res c ->
|
||||
match c, side with
|
||||
| Bnrl (o, c, r), `L -> ((o, c), r) :: res
|
||||
| Bnrr (o, l, c), `R -> ((o, c), l) :: res
|
||||
| _ -> res)
|
||||
[] point
|
||||
|
||||
let group_by_fst l =
|
||||
List.fast_sort (fun (a, _) (b, _) ->
|
||||
compare a b) l |>
|
||||
List.fold_left (fun (oo, l, res) (o', c) ->
|
||||
match oo with
|
||||
| None -> (Some o', [c], [])
|
||||
| Some o when o = o' -> (oo, c :: l, res)
|
||||
| Some o -> (Some o', [c], (o, l) :: res))
|
||||
(None, [], []) |>
|
||||
(function
|
||||
| (None, _, _) -> []
|
||||
| (Some o, l, res) -> (o, l) :: res)
|
||||
|
||||
let sort_uniq cmp l =
|
||||
List.fast_sort cmp l |>
|
||||
List.fold_left (fun (eo, l) e' ->
|
||||
match eo with
|
||||
| None -> (Some e', l)
|
||||
| Some e ->
|
||||
if cmp e e' = 0
|
||||
then (eo, l)
|
||||
else (Some e', e :: l)
|
||||
) (None, []) |>
|
||||
(function
|
||||
| (None, _) -> []
|
||||
| (Some e, l) -> List.rev (e :: l))
|
||||
|
||||
let normalize (point: ('a cursor) list) =
|
||||
sort_uniq compare point
|
||||
|
||||
let nextbnr tmp s1 s2 =
|
||||
let pm w (_, p) = pattern_match p w in
|
||||
let o1 = binops `L s1 |>
|
||||
List.filter (pm s2.seen) |>
|
||||
List.map fst
|
||||
and o2 = binops `R s2 |>
|
||||
List.filter (pm s1.seen) |>
|
||||
List.map fst
|
||||
in
|
||||
List.map (fun (o, l) ->
|
||||
o,
|
||||
{ id = 0
|
||||
; seen = Bnr (o, s1.seen, s2.seen)
|
||||
; point = normalize (l @ tmp)
|
||||
}) (group_by_fst (o1 @ o2))
|
||||
|
||||
type p = string
|
||||
|
||||
module StateSet : sig
|
||||
type set
|
||||
val create: unit -> set
|
||||
val add: set -> p state ->
|
||||
[> `Added | `Found ] * p state
|
||||
val iter: set -> (p state -> unit) -> unit
|
||||
val elems: set -> (p state) list
|
||||
end = struct
|
||||
include Hashtbl.Make(struct
|
||||
type t = p state
|
||||
let equal s1 s2 = s1.point = s2.point
|
||||
let hash s = Hashtbl.hash s.point
|
||||
end)
|
||||
type set =
|
||||
{ h: int t
|
||||
; mutable next_id: int }
|
||||
let create () =
|
||||
{ h = create 500; next_id = 1 }
|
||||
let add set s =
|
||||
(* delete the check later *)
|
||||
assert (s.point = normalize s.point);
|
||||
try
|
||||
let id = find set.h s in
|
||||
`Found, {s with id}
|
||||
with Not_found -> begin
|
||||
let id = set.next_id in
|
||||
set.next_id <- id + 1;
|
||||
add set.h s id;
|
||||
`Added, {s with id}
|
||||
end
|
||||
let iter set f =
|
||||
let f s id = f {s with id} in
|
||||
iter f set.h
|
||||
let elems set =
|
||||
let res = ref [] in
|
||||
iter set (fun s -> res := s :: !res);
|
||||
!res
|
||||
end
|
||||
|
||||
type table_key =
|
||||
| K of op * p state * p state
|
||||
|
||||
module StateMap = Map.Make(struct
|
||||
type t = table_key
|
||||
let compare ka kb =
|
||||
match ka, kb with
|
||||
| K (o, sl, sr), K (o', sl', sr') ->
|
||||
compare (o, sl.id, sr.id)
|
||||
(o', sl'.id, sr'.id)
|
||||
end)
|
||||
|
||||
type rule =
|
||||
{ name: string
|
||||
; pattern: pattern
|
||||
(* TODO access pattern *)
|
||||
}
|
||||
|
||||
let generate_table rl =
|
||||
let states = StateSet.create () in
|
||||
(* initialize states *)
|
||||
let ground =
|
||||
List.fold_left
|
||||
(fun ini r ->
|
||||
peel r.pattern r.name @ ini)
|
||||
[] rl |>
|
||||
group_by_fst
|
||||
in
|
||||
let find x d l =
|
||||
try List.assoc x l with Not_found -> d in
|
||||
let tmp = find (Atm Tmp) [] ground in
|
||||
let con = find (Atm AnyCon) [] ground in
|
||||
let () =
|
||||
List.iter (fun (seen, l) ->
|
||||
let point =
|
||||
if pattern_match (Atm Tmp) seen
|
||||
then normalize (tmp @ l)
|
||||
else normalize (con @ l)
|
||||
in
|
||||
let s = {id = 0; seen; point} in
|
||||
let flag, _ = StateSet.add states s in
|
||||
assert (flag = `Added)
|
||||
) ground
|
||||
in
|
||||
(* setup loop state *)
|
||||
let map = ref StateMap.empty in
|
||||
let map_add k s' =
|
||||
map := StateMap.add k s' !map
|
||||
in
|
||||
let flag = ref `Added in
|
||||
let flagmerge = function
|
||||
| `Added -> flag := `Added
|
||||
| _ -> ()
|
||||
in
|
||||
(* iterate until fixpoint *)
|
||||
while !flag = `Added do
|
||||
flag := `Stop;
|
||||
let statel = StateSet.elems states in
|
||||
iter_pairs statel (fun (sl, sr) ->
|
||||
nextbnr tmp sl sr |>
|
||||
List.iter (fun (o, s') ->
|
||||
let flag', s' =
|
||||
StateSet.add states s' in
|
||||
flagmerge flag';
|
||||
map_add (K (o, sl, sr)) s';
|
||||
));
|
||||
done;
|
||||
(StateSet.elems states, !map)
|
||||
|
||||
let intersperse x l =
|
||||
let rec go left right out =
|
||||
let out =
|
||||
(List.rev left @ [x] @ right) ::
|
||||
out in
|
||||
match right with
|
||||
| x :: right' ->
|
||||
go (x :: left) right' out
|
||||
| [] -> out
|
||||
in go [] l []
|
||||
|
||||
let rec permute = function
|
||||
| [] -> [[]]
|
||||
| x :: l ->
|
||||
List.concat (List.map
|
||||
(intersperse x) (permute l))
|
||||
|
||||
(* build all binary trees with ordered
|
||||
* leaves l *)
|
||||
let rec bins build l =
|
||||
let rec go l r out =
|
||||
match r with
|
||||
| [] -> out
|
||||
| x :: r' ->
|
||||
go (l @ [x]) r'
|
||||
(fold_pairs
|
||||
(bins build l)
|
||||
(bins build r)
|
||||
out (fun (l, r) out ->
|
||||
build l r :: out))
|
||||
in
|
||||
match l with
|
||||
| [] -> []
|
||||
| [x] -> [x]
|
||||
| x :: l -> go [x] l []
|
||||
|
||||
let products l ini f =
|
||||
let rec go acc la = function
|
||||
| [] -> f (List.rev la) acc
|
||||
| xs :: l ->
|
||||
List.fold_left (fun acc x ->
|
||||
go acc (x :: la) l)
|
||||
acc xs
|
||||
in go ini [] l
|
||||
|
||||
(* combinatorial nuke... *)
|
||||
let rec ac_equiv =
|
||||
let rec alevel o = function
|
||||
| Bnr (o', l, r) when o' = o ->
|
||||
alevel o l @ alevel o r
|
||||
| x -> [x]
|
||||
in function
|
||||
| Bnr (o, _, _) as p
|
||||
when associative o ->
|
||||
products
|
||||
(List.map ac_equiv (alevel o p)) []
|
||||
(fun choice out ->
|
||||
List.map
|
||||
(bins (fun l r -> Bnr (o, l, r)))
|
||||
(if commutative o
|
||||
then permute choice
|
||||
else [choice]) |>
|
||||
List.concat |>
|
||||
(fun l -> List.rev_append l out))
|
||||
| Bnr (o, l, r)
|
||||
when commutative o ->
|
||||
fold_pairs
|
||||
(ac_equiv l) (ac_equiv r) []
|
||||
(fun (l, r) out ->
|
||||
Bnr (o, l, r) ::
|
||||
Bnr (o, r, l) :: out)
|
||||
| Bnr (o, l, r) ->
|
||||
fold_pairs
|
||||
(ac_equiv l) (ac_equiv r) []
|
||||
(fun (l, r) out ->
|
||||
Bnr (o, l, r) :: out)
|
||||
| x -> [x]
|
||||
|
|
102
tools/match_test.ml
Normal file
102
tools/match_test.ml
Normal file
|
@ -0,0 +1,102 @@
|
|||
#use "match.ml"
|
||||
|
||||
let test_pattern_match =
|
||||
let pm = pattern_match
|
||||
and nm = fun x y -> not (pattern_match x y) in
|
||||
begin
|
||||
assert (nm (Atm Tmp) (Atm (Con 42L)));
|
||||
assert (pm (Atm AnyCon) (Atm (Con 42L)));
|
||||
assert (nm (Atm (Con 42L)) (Atm AnyCon));
|
||||
assert (nm (Atm (Con 42L)) (Atm Tmp));
|
||||
end
|
||||
|
||||
let test_peel =
|
||||
let o = Kw, Oadd in
|
||||
let p = Bnr (o, Bnr (o, Atm Tmp, Atm Tmp),
|
||||
Atm (Con 42L)) in
|
||||
let l = peel p () in
|
||||
let () = assert (List.length l = 3) in
|
||||
let atomic_p (p, _) =
|
||||
match p with Atm _ -> true | _ -> false in
|
||||
let () = assert (List.for_all atomic_p l) in
|
||||
let l = List.map (fun (p, c) -> fold_cursor c p) l in
|
||||
let () = assert (List.for_all ((=) p) l) in
|
||||
()
|
||||
|
||||
let test_fold_pairs =
|
||||
let l = [1; 2; 3; 4; 5] in
|
||||
let p = fold_pairs l l [] (fun a b -> a :: b) in
|
||||
let () = assert (List.length p = 25) in
|
||||
let p = sort_uniq compare p in
|
||||
let () = assert (List.length p = 25) in
|
||||
()
|
||||
|
||||
(* test pattern & state *)
|
||||
let tp =
|
||||
let o = Kw, Oadd in
|
||||
Bnr (o, Bnr (o, Atm Tmp, Atm Tmp),
|
||||
Atm (Con 0L))
|
||||
let ts =
|
||||
{ id = 0
|
||||
; seen = Atm Tmp
|
||||
; point =
|
||||
List.map snd
|
||||
(List.filter (fun (p, _) -> p = Atm Tmp)
|
||||
(peel tp ()))
|
||||
}
|
||||
|
||||
let print_sm =
|
||||
let op_str (k, o) =
|
||||
Printf.sprintf "%s%s"
|
||||
(match o with
|
||||
| Oadd -> "add"
|
||||
| Osub -> "sub"
|
||||
| Omul -> "mul")
|
||||
(match k with
|
||||
| Kw -> "w"
|
||||
| Kl -> "l"
|
||||
| Ks -> "s"
|
||||
| Kd -> "d")
|
||||
in
|
||||
StateMap.iter (fun k s' ->
|
||||
match k with
|
||||
| K (o, sl, sr) ->
|
||||
Printf.printf
|
||||
"(%s %d %d) -> %d\n"
|
||||
(op_str o)
|
||||
sl.id sr.id s'.id
|
||||
)
|
||||
|
||||
let address_rules =
|
||||
let oa = Kl, Oadd in
|
||||
let om = Kl, Omul in
|
||||
let rule name pattern = { name; pattern; } in
|
||||
(* o + b *)
|
||||
[ rule "ob1" (Bnr (oa, Atm Tmp, Atm AnyCon))
|
||||
; rule "ob2" (Bnr (oa, Atm AnyCon, Atm Tmp))
|
||||
|
||||
(* b + s * i *)
|
||||
; rule "bs1" (Bnr (oa, Atm Tmp, Bnr (om, Atm AnyCon, Atm Tmp)))
|
||||
; rule "bs2" (Bnr (oa, Atm Tmp, Bnr (om, Atm Tmp, Atm AnyCon)))
|
||||
; rule "bs3" (Bnr (oa, Bnr (om, Atm AnyCon, Atm Tmp), Atm Tmp))
|
||||
; rule "bs4" (Bnr (oa, Bnr (om, Atm Tmp, Atm AnyCon), Atm Tmp))
|
||||
|
||||
(* o + s * i *)
|
||||
; rule "os1" (Bnr (oa, Atm AnyCon, Bnr (om, Atm AnyCon, Atm Tmp)))
|
||||
; rule "os2" (Bnr (oa, Atm AnyCon, Bnr (om, Atm Tmp, Atm AnyCon)))
|
||||
; rule "os3" (Bnr (oa, Bnr (om, Atm AnyCon, Atm Tmp), Atm AnyCon))
|
||||
; rule "os4" (Bnr (oa, Bnr (om, Atm Tmp, Atm AnyCon), Atm AnyCon))
|
||||
]
|
||||
|
||||
(*
|
||||
let sl, sm = generate_table address_rules
|
||||
let s n = List.find (fun {id; _} -> id = n) sl
|
||||
let () = print_sm sm
|
||||
*)
|
||||
|
||||
let tp0 =
|
||||
let o = Kw, Oadd in
|
||||
Bnr (o, Atm Tmp, Atm (Con 0L))
|
||||
let tp1 =
|
||||
let o = Kw, Oadd in
|
||||
Bnr (o, tp0, Atm (Con 1L))
|
Loading…
Add table
Reference in a new issue