fuse ac rules in ins-tree matching
The initial plan was to have one matcher per ac-variant, but that leads to way too much generated code. Instead, we can fuse ac variants of the rules and have a smarter matching algorithm to recover bound variables.
This commit is contained in:
parent
8a5e1c3a23
commit
56e2263ca4
2 changed files with 86 additions and 72 deletions
|
@ -23,11 +23,32 @@ type pattern =
|
|||
| Atm of atomic_pattern
|
||||
| Var of string * atomic_pattern
|
||||
|
||||
let show_op (k, o) =
|
||||
(match o with
|
||||
| Oadd -> "add"
|
||||
| Osub -> "sub"
|
||||
| Omul -> "mul") ^
|
||||
(match k with
|
||||
| Kw -> "w"
|
||||
| Kl -> "l"
|
||||
| Ks -> "s"
|
||||
| Kd -> "d")
|
||||
|
||||
let rec show_pattern p =
|
||||
match p with
|
||||
| Var _ -> failwith "variable not allowed"
|
||||
| Atm Tmp -> "%"
|
||||
| Atm AnyCon -> "$"
|
||||
| Atm (Con n) -> Int64.to_string n
|
||||
| Bnr (o, pl, pr) ->
|
||||
"(" ^ show_op o ^
|
||||
" " ^ show_pattern pl ^
|
||||
" " ^ show_pattern pr ^ ")"
|
||||
|
||||
let rec pattern_match p w =
|
||||
match p with
|
||||
| Var _ ->
|
||||
failwith "variable not allowed"
|
||||
| Atm (Tmp) ->
|
||||
| Var _ -> failwith "variable not allowed"
|
||||
| Atm Tmp ->
|
||||
begin match w with
|
||||
| Atm (Con _ | AnyCon) -> false
|
||||
| _ -> true
|
||||
|
@ -89,12 +110,12 @@ type 'a state =
|
|||
; point: ('a cursor) list }
|
||||
|
||||
let rec binops side {point; _} =
|
||||
List.fold_left (fun res c ->
|
||||
List.filter_map (fun c ->
|
||||
match c, side with
|
||||
| Bnrl (o, c, r), `L -> ((o, c), r) :: res
|
||||
| Bnrr (o, l, c), `R -> ((o, c), l) :: res
|
||||
| _ -> res)
|
||||
[] point
|
||||
| Bnrl (o, c, r), `L -> Some ((o, c), r)
|
||||
| Bnrr (o, l, c), `R -> Some ((o, c), l)
|
||||
| _ -> None)
|
||||
point
|
||||
|
||||
let group_by_fst l =
|
||||
List.fast_sort (fun (a, _) (b, _) ->
|
||||
|
@ -114,11 +135,9 @@ let sort_uniq cmp l =
|
|||
List.fold_left (fun (eo, l) e' ->
|
||||
match eo with
|
||||
| None -> (Some e', l)
|
||||
| Some e ->
|
||||
if cmp e e' = 0
|
||||
then (eo, l)
|
||||
else (Some e', e :: l)
|
||||
) (None, []) |>
|
||||
| Some e when cmp e e' = 0 -> (eo, l)
|
||||
| Some e -> (Some e', e :: l))
|
||||
(None, []) |>
|
||||
(function
|
||||
| (None, _) -> []
|
||||
| (Some e, l) -> List.rev (e :: l))
|
||||
|
@ -126,15 +145,14 @@ let sort_uniq cmp l =
|
|||
let normalize (point: ('a cursor) list) =
|
||||
sort_uniq compare point
|
||||
|
||||
let nextbnr tmp s1 s2 =
|
||||
let next_binary tmp s1 s2 =
|
||||
let pm w (_, p) = pattern_match p w in
|
||||
let o1 = binops `L s1 |>
|
||||
List.filter (pm s2.seen) |>
|
||||
List.map fst
|
||||
and o2 = binops `R s2 |>
|
||||
List.map fst in
|
||||
let o2 = binops `R s2 |>
|
||||
List.filter (pm s1.seen) |>
|
||||
List.map fst
|
||||
in
|
||||
List.map fst in
|
||||
List.map (fun (o, l) ->
|
||||
o,
|
||||
{ id = 0
|
||||
|
@ -145,25 +163,24 @@ let nextbnr tmp s1 s2 =
|
|||
type p = string
|
||||
|
||||
module StateSet : sig
|
||||
type set
|
||||
val create: unit -> set
|
||||
val add: set -> p state ->
|
||||
type t
|
||||
val create: unit -> t
|
||||
val add: t -> p state ->
|
||||
[> `Added | `Found ] * p state
|
||||
val iter: set -> (p state -> unit) -> unit
|
||||
val elems: set -> (p state) list
|
||||
val iter: t -> (p state -> unit) -> unit
|
||||
val elems: t -> (p state) list
|
||||
end = struct
|
||||
include Hashtbl.Make(struct
|
||||
open Hashtbl.Make(struct
|
||||
type t = p state
|
||||
let equal s1 s2 = s1.point = s2.point
|
||||
let hash s = Hashtbl.hash s.point
|
||||
end)
|
||||
type set =
|
||||
type nonrec t =
|
||||
{ h: int t
|
||||
; mutable next_id: int }
|
||||
let create () =
|
||||
{ h = create 500; next_id = 1 }
|
||||
let add set s =
|
||||
(* delete the check later *)
|
||||
assert (s.point = normalize s.point);
|
||||
try
|
||||
let id = find set.h s in
|
||||
|
@ -171,6 +188,8 @@ end = struct
|
|||
with Not_found -> begin
|
||||
let id = set.next_id in
|
||||
set.next_id <- id + 1;
|
||||
Printf.printf "adding: %d [%s]\n"
|
||||
id (show_pattern s.seen);
|
||||
add set.h s id;
|
||||
`Added, {s with id}
|
||||
end
|
||||
|
@ -198,17 +217,14 @@ end)
|
|||
type rule =
|
||||
{ name: string
|
||||
; pattern: pattern
|
||||
(* TODO access pattern *)
|
||||
}
|
||||
|
||||
let generate_table rl =
|
||||
let states = StateSet.create () in
|
||||
(* initialize states *)
|
||||
let ground =
|
||||
List.fold_left
|
||||
(fun ini r ->
|
||||
peel r.pattern r.name @ ini)
|
||||
[] rl |>
|
||||
List.concat_map
|
||||
(fun r -> peel r.pattern r.name) rl |>
|
||||
group_by_fst
|
||||
in
|
||||
let find x d l =
|
||||
|
@ -242,7 +258,7 @@ let generate_table rl =
|
|||
flag := `Stop;
|
||||
let statel = StateSet.elems states in
|
||||
iter_pairs statel (fun (sl, sr) ->
|
||||
nextbnr tmp sl sr |>
|
||||
next_binary tmp sl sr |>
|
||||
List.iter (fun (o, s') ->
|
||||
let flag', s' =
|
||||
StateSet.add states s' in
|
||||
|
|
|
@ -46,54 +46,52 @@ let ts =
|
|||
}
|
||||
|
||||
let print_sm =
|
||||
let op_str (k, o) =
|
||||
Printf.sprintf "%s%s"
|
||||
(match o with
|
||||
| Oadd -> "add"
|
||||
| Osub -> "sub"
|
||||
| Omul -> "mul")
|
||||
(match k with
|
||||
| Kw -> "w"
|
||||
| Kl -> "l"
|
||||
| Ks -> "s"
|
||||
| Kd -> "d")
|
||||
in
|
||||
StateMap.iter (fun k s' ->
|
||||
match k with
|
||||
| K (o, sl, sr) ->
|
||||
let top =
|
||||
List.fold_left (fun top c ->
|
||||
match c with
|
||||
| Top r -> top ^ " " ^ r
|
||||
| _ -> top) "" s'.point
|
||||
in
|
||||
Printf.printf
|
||||
"(%s %d %d) -> %d\n"
|
||||
(op_str o)
|
||||
sl.id sr.id s'.id
|
||||
)
|
||||
"(%s %d %d) -> %d%s\n"
|
||||
(show_op o)
|
||||
sl.id sr.id s'.id top)
|
||||
|
||||
let address_rules =
|
||||
let rules =
|
||||
let oa = Kl, Oadd in
|
||||
let om = Kl, Omul in
|
||||
match `X64Addr with
|
||||
(* ------------------------------- *)
|
||||
| `X64Addr ->
|
||||
let rule name pattern =
|
||||
List.mapi (fun i pattern ->
|
||||
{ name = Printf.sprintf "%s%d" name (i+1)
|
||||
; pattern; })
|
||||
{ name (* = Printf.sprintf "%s%d" name (i+1) *)
|
||||
; pattern })
|
||||
(ac_equiv pattern) in
|
||||
|
||||
(* o + b *)
|
||||
rule "ob" (Bnr (oa, Atm Tmp, Atm AnyCon))
|
||||
@ (* b + s * i *)
|
||||
rule "bs" (Bnr (oa, Atm Tmp, Bnr (om, Atm AnyCon, Atm Tmp)))
|
||||
rule "bs" (Bnr (oa, Atm Tmp, Bnr (om, Atm (Con 4L), Atm Tmp)))
|
||||
@ (* o + s * i *)
|
||||
rule "os" (Bnr (oa, Atm AnyCon, Bnr (om, Atm AnyCon, Atm Tmp)))
|
||||
rule "os" (Bnr (oa, Atm AnyCon, Bnr (om, Atm (Con 4L), Atm Tmp)))
|
||||
@ (* b + o + s * i *)
|
||||
rule "bos" (Bnr (oa, Bnr (oa, Atm AnyCon, Atm Tmp), Bnr (om, Atm AnyCon, Atm Tmp)))
|
||||
rule "bos" (Bnr (oa, Bnr (oa, Atm AnyCon, Atm Tmp), Bnr (om, Atm (Con 4L), Atm Tmp)))
|
||||
(* ------------------------------- *)
|
||||
| `Add3 ->
|
||||
[ { name = "add"
|
||||
; pattern = Bnr (oa, Atm Tmp, Bnr (oa, Atm Tmp, Atm Tmp)) } ] @
|
||||
[ { name = "add"
|
||||
; pattern = Bnr (oa, Bnr (oa, Atm Tmp, Atm Tmp), Atm Tmp) } ] @
|
||||
[ { name = "mul"
|
||||
; pattern = Bnr (om, Bnr (oa, Bnr (oa, Atm Tmp, Atm Tmp),
|
||||
Atm Tmp),
|
||||
Bnr (oa, Atm Tmp,
|
||||
Bnr (oa, Atm Tmp, Atm Tmp))) } ]
|
||||
|
||||
let sl, sm = generate_table address_rules
|
||||
|
||||
let sl, sm = generate_table rules
|
||||
let s n = List.find (fun {id; _} -> id = n) sl
|
||||
let () = print_sm sm
|
||||
|
||||
(*
|
||||
let tp0 =
|
||||
let o = Kw, Oadd in
|
||||
Bnr (o, Atm Tmp, Atm (Con 0L))
|
||||
let tp1 =
|
||||
let o = Kw, Oadd in
|
||||
Bnr (o, tp0, Atm (Con 1L))
|
||||
*)
|
||||
|
|
Loading…
Add table
Reference in a new issue