Revision 6ee5d28a8988618b187e10ebbe680be736c694c4 authored by fclement on 21 September 2006, 15:57:14 UTC, committed by fclement on 21 September 2006, 15:57:14 UTC
1 parent 4442830
Raw File
pa_p3l.ml
(* camlp4r *)
(* $Id: pa_p3l.ml,v 1.3 2005-07-21 18:37:08 weis Exp $ *)

#load "pa_extend.cmo";
#load "q_MLast.cmo";

open Pcaml;

(* One index of an array access in the "body in/out" clauses of P3LMap. *)
type body_index =
  [ BIident of string
    (* [*i]: a single element, selected by index variable [i] *)
  | BIrange of string and string and string
    (* [*i+d1:*i+d2]: the variable [i] and the two offsets, kept as the
       source integer literals (an offset may carry a leading "-") *)
  | BIempty
    (* []: the whole dimension *) ]
;

(* One parameter of the "body in/out" clauses of P3LMap. *)
type body_param =
  [ BParray of MLast.loc and string and list body_index
    (* a[..][..]: array [a] accessed with one index per dimension; the
       location is the one of the name, used for error reporting *)
  | BPvar of MLast.loc and string
    (* *x: a plain variable *) ]
;

(* Build the expression of a list literal from the expressions [el]:
   [e1; ...; en] becomes [e1 :: ... :: en :: []], located at [_loc].
   The list is consumed from its last element so the conses nest in order. *)
value make_expr_list _loc el =
  let tail = <:expr< [] >> in
  List.fold_left (fun acc e -> <:expr< [$e$ :: $acc$] >>) tail (List.rev el)
;

(* Build a list pattern from the patterns [pl]:
   [p1; ...; pn] becomes [p1 :: ... :: pn :: []], located at [_loc]. *)
value make_patt_list _loc pl =
  build pl where rec build =
    fun
    [ [] -> <:patt< [] >>
    | [p :: rest] -> <:patt< [$p$ :: $build rest$] >> ]
;

(* Name of the accessor function generated for an index pattern [pl].
   The name is "get" followed by one character per dimension: that
   dimension's letter ('i' for the first, 'j' for the second, ...) for a
   single index, "x" for a range and "_" for a whole dimension.
   Example: [[*a]; [*b-1:*b+1]; []] gives "getix_". *)
value dim_fun_name pl =
  let suffix letter =
    fun
    [ BIident _ -> String.make 1 letter
    | BIrange _ _ _ -> "x"
    | BIempty -> "_" ]
  in
  let rec build acc letter =
    fun
    [ [] -> acc
    | [p :: rest] ->
        build (acc ^ suffix letter p) (Char.chr (Char.code letter + 1)) rest ]
  in
  build "get" 'i' pl
;

(* Build the accessor function that [dim_fun_call] applies for one index
   pattern [dlist] (the [body_index] list of an array parameter).
   Dimension k of the pattern is named by the k-th letter from 'i'
   ("i", "j", "k", ...), as in [dim_fun_name].

   The generated expression has the shape
     fun p1 -> ... -> fun pm -> fun a -> <result>
   where p1..pm follow the dimension order and are:
   - [BIident]: the index variable itself, e.g. [i];
   - [BIrange]: a triple [(di, ilen, ni)] (start, length, dimension size);
   - [BIempty] with some non-[BIempty] index to its right: the size [ni];
   - trailing [BIempty]s: no parameter at all.
   <result> indexes [a] by every [BIident] dimension.  It builds fresh arrays
   with [Array.init] along [BIrange] dimensions (the index is taken modulo
   the dimension size, negative values wrapped) and along non-trailing
   [BIempty] dimensions.  It returns the trailing [BIempty] dimensions as
   the existing sub-array of [a], without copying.
   Only the constructors of [dlist] matter; their string payloads are
   ignored. *)
value dim_fun_body _loc dlist =
  (* Pair each index with its letter; [revl] is in reverse dimension order
     (last dimension first). *)
  let (revl, _) =
    List.fold_left
      (fun (l, cnt) d ->
         let n = String.make 1 (Char.chr (Char.code 'i' + cnt)) in
         ([(d, n) :: l], cnt + 1))
      ([], 0) dlist
  in
  (* Build the sub-matrix expression, visiting dimensions from last to
     first.  The second accumulator is [Some rest] while only trailing
     [BIempty]s have been seen ([rest] = dimensions not visited yet, current
     one included).  The first other dimension emits all the [.(x)]
     accesses at once, then switches the accumulator to [None]. *)
  let (e, _) =
    List.fold_left
      (fun (e, revl) (d, iname) ->
         match d with
         [ BIident _ ->
             let e =
               match revl with
               [ Some revl ->
                   List.fold_right
                     (fun (_, ind) e -> <:expr< $e$ .($lid:ind$) >>)
                     revl e
               | None -> e ]
             in
             (e, None)
         | BIrange _ _ _ ->
             let e =
               match revl with
               [ Some revl ->
                   List.fold_right
                     (fun (_, ind) e -> <:expr< $e$ .($lid:ind$) >>)
                     revl e
               | None -> e ]
             in
             let di = "d" ^ iname in
             let ilen = iname ^ "len" in
             let n = "n" ^ iname in
             (* Bind the range variable to the wrapped absolute index
                (start + offset) mod size, made non-negative. *)
             let e =
               <:expr<
                  Array.init $lid:ilen$
                    (fun $lid:iname$ ->
                       let $lid:iname$ =
                         let v = ($lid:iname$ + $lid:di$) mod $lid:n$ in
                         if v < 0 then v + $lid:n$ else v
                       in
                       $e$)
               >>
             in
             (e, None)
         | BIempty ->
             match revl with
             [ Some [_ :: revl] -> (e, Some revl) (* trailing []: keep sub-array *)
             | Some _ -> assert False (* [revl] always holds the current dim *)
             | None ->
                 (* Non-trailing []: rebuild this dimension over its size. *)
                 let nname = "n" ^ iname in
                 let e =
                   <:expr< Array.init $lid:nname$ (fun $lid:iname$ -> $e$) >>
                 in
                 (e, None) ] ])
      (<:expr< a >>, Some revl) revl
  in
  (* Add the parameters, again from last dimension to first, so that the
     first dimension's parameter ends up outermost.  Trailing [BIempty]s
     (those met before any other index) add no parameter. *)
  let e = <:expr< fun a -> $e$ >> in
  let (e, _) =
    List.fold_left
      (fun (e, notemptyfound) (d, iname) ->
         match d with
         [ BIident _ -> (<:expr< fun $lid:iname$ -> $e$ >>, True)
         | BIrange _ _ _ ->
             let di = "d" ^ iname in
             let ilen = iname ^ "len" in
             let n = "n" ^ iname in
             (<:expr< fun ($lid:di$, $lid:ilen$, $lid:n$) -> $e$ >>, True)
         | BIempty ->
             if notemptyfound then
               let nname = "n" ^ iname in
               (<:expr< fun $lid:nname$ -> $e$ >>, True)
             else (e, False) ])
      (e, False) revl
  in
  e
;

(* Declared dimension sizes of input variable [id].  [idef] holds the parsed
   "in" declarations as [((loc, name), sizes)] pairs; the first declaration
   named [id] wins.  Raises [Failure "unbound variable <id>"], located at
   [_loc], when [id] is not declared. *)
value get_dims _loc id idef =
  let by_name = List.map (fun ((_, name), sizes) -> (name, sizes)) idef in
  try List.assoc id by_name with
  [ Not_found ->
      Stdpp.raise_with_loc _loc (Failure ("unbound variable " ^ id)) ]
;

(* Declared size (an integer literal, as a string) of dimension [cnt]
   (counted from 1) of variable [id], whose declared sizes are [dims].
   Raises a [Failure], located at [_loc], when the declaration has fewer
   than [cnt] dimensions. *)
value nth_dim _loc id cnt dims =
  let found =
    try Some (List.nth dims (cnt - 1)) with [ Failure _ -> None ]
  in
  match found with
  [ Some size -> size
  | None ->
      Stdpp.raise_with_loc _loc
        (Failure ("dimension inconsistent with definition of '" ^ id ^ "'")) ]
;

(* Build the call to the accessor generated by [dim_fun_body] for the
   body-in parameter [id] accessed with index pattern [pl]:
     getXYZ arg1 ... argk id
   The arguments follow the dimension order of [pl]:
   - [*i]            -> the index variable [i];
   - [*i+d1:*i+d2]   -> the triple (i + d1, d2 - d1 + 1, declared size);
   - [] with some non-[] index to its right -> the declared size;
   - trailing []     -> no argument.
   Declared sizes come from the "in" declarations [idef].
   Raises a [Failure], located at [_loc], in three cases:
   - [id] is undeclared;
   - a size is needed for a dimension the declaration lacks;
   - a range is reversed (d2 < d1 - 1).  Such a range would otherwise
     produce a negative [Array.init] length that fails only at run time. *)
value dim_fun_call idef _loc id pl =
  let dims = get_dims _loc id idef in
  let e = <:expr< $lid:dim_fun_name pl$ >> in
  (* Scan from the last dimension so that trailing [] can be recognised;
     [cnt] is the 1-based number of the dimension being visited. *)
  let (list, _, _) =
    let revl = List.rev pl in
    List.fold_left
      (fun (list, cnt, notemptyfound) p ->
         match p with
         [ BIident i -> ([<:expr< $lid:i$ >> :: list], cnt - 1, True)
         | BIrange i d1 d2 ->
             let n = nth_dim _loc id cnt dims in
             let d1 = int_of_string d1 in
             let len = int_of_string d2 - d1 + 1 in
             if len < 0 then
               Stdpp.raise_with_loc _loc
                 (Failure ("reversed index range for '" ^ id ^ "'"))
             else
               let ei =
                 if d1 < 0 then
                   <:expr< $lid:i$ - $int:string_of_int (-d1)$ >>
                 else
                   <:expr< $lid:i$ + $int:string_of_int d1$ >>
               in
               ([<:expr< ($ei$, $int:string_of_int len$, $int:n$) >> :: list],
                cnt - 1, True)
         | BIempty ->
             if notemptyfound then
               let n = nth_dim _loc id cnt dims in
               ([<:expr< $int:n$ >> :: list], cnt - 1, True)
             else (list, cnt - 1, False) ])
      ([], List.length revl, False) revl
  in
  let e = List.fold_left (fun e1 e2 -> <:expr< $e1$ $e2$ >>) e list in
  <:expr< $e$ $lid:id$ >>
;

(* Build the expression allocating the gather result for the output
   declaration [(_, pl)].  [pl] holds the declared sizes, outermost first.
   Every innermost cell is set to [init].  For sizes [d1; d2] the result is
     Array.init d1 (fun _ -> Array.make d2 init)
   so the rows are distinct arrays rather than aliases of one another.
   With no declared dimension the result is the literal [0], as before.
   [Array.make] replaces [Array.create], which is deprecated and no longer
   exists in OCaml 5. *)
value make_gath_array _loc (_, pl) init =
  (* [None] until the innermost dimension has been emitted, which avoids
     recognising "nothing built yet" by matching a [0] expression. *)
  let alloc =
    List.fold_right
      (fun p inner ->
         let sz = <:expr< $int:p$ >> in
         match inner with
         [ None -> Some <:expr< Array.make $sz$ $init$ >>
         | Some e -> Some <:expr< Array.init $sz$ (fun _ -> $e$) >> ])
      pl None
  in
  match alloc with
  [ Some e -> e
  | None -> <:expr< 0 >> ]
;

(* One fresh variable name per declared output dimension: "x1", "x2", ...,
   in dimension order. *)
value gath_param_names (_, pl) =
  let rec names k =
    fun
    [ [] -> []
    | [_ :: rest] -> ["x" ^ string_of_int k :: names (k + 1) rest] ]
  in
  names 1 pl
;

(* List pattern [x1 :: ... :: xn :: []] binding one coordinate per output
   dimension of [odef]. *)
value gath_param _loc odef =
  make_patt_list _loc
    (List.map (fun name -> <:patt< $lid:name$ >>) (gath_param_names odef))
;

(* Access expression [z.(x1).(x2)...(xn)] into the gather result [z], using
   the coordinate names bound by [gath_param]. *)
value gath_acc _loc odef =
  let rev_names = List.rev (gath_param_names odef) in
  List.fold_right (fun s e -> <:expr< $e$.($lid:s$) >>) rev_names <:expr< z >>
;

(* Grammar extension adding a "P3LMap" expression at the "apply" level,
   presumably modelled on the P3L skeleton language (inferred from the
   file name):

     P3LMap in ( <in_def> )
       out ( <out_def> [, init = <expr>] )
       nworkers [ *w1 = n1; *w2 = n2; ... ]
       body in ( <body_in_def> ) out ( <body_out_def> )
       <expr>
     end

   It expands to
     let getXYZ = ... in ...            (one accessor per index shape)
     let scatter <inputs> = ... in
     let gather l = ... in
     matrix_map (odim, nwork, body, scatter, gather)
   [matrix_map] is not defined here; it must be in scope where the
   construct is used. *)
EXTEND
  GLOBAL: expr;
  expr: LEVEL "apply"
    [ [ "P3LMap"; "in"; "("; idef = in_def; ")";
        LIDENT "out"; "("; odef = out_def; oinit = out_init; ")";
        LIDENT "nworkers"; "["; w = workers_def; "]"; LIDENT "body"; "in";
        "("; ibod = body_in_def; ")"; LIDENT "out"; "("; obod = body_out_def;
        ")"; body = expr; "end" ->
          (* NOTE(review): [obod] (the body's "out" clause) is parsed but not
             used by the expansion below. *)
          (* List literal of the declared output sizes. *)
          let odim =
            let el = List.map (fun i -> <:expr< $int:i$ >>) (snd odef) in
            make_expr_list _loc el
          in
          (* List literal of the worker counts, in declaration order. *)
          let nwork =
            let el = List.map (fun (_, n) -> <:expr< $int:n$ >>) w in
            make_expr_list _loc el
          in
          (* Parameter of [scatter]: the input names, as a single variable
             or a tuple.  NOTE(review): an empty "in" list gives a
             zero-element tuple pattern here, presumably unsupported. *)
          let scat_prm =
            let pl =
              List.map (fun ((_loc, id), _) -> <:patt< $lid:id$ >>) idef
            in
            match pl with
            [ [p] -> p
            | _ -> <:patt< ($list:pl$) >> ]
          in
          (* [scatter]'s body: a function of the list of worker coordinates
             [w1; w2; ...] returning the body inputs.  An array parameter
             becomes an accessor call (see [dim_fun_call]); a [*x] parameter
             becomes the variable [x] (presumably one of the worker names
             bound by [p]). *)
          let scat_body =
            let p =
              let pl =
                List.map (fun ((_loc, id), _) -> <:patt< $lid:id$ >>) w
              in
              make_patt_list _loc pl
            in
            let el =
              List.map
                (fun
                 [ BParray _loc id pl -> dim_fun_call idef _loc id pl
                 | BPvar _loc id -> <:expr< $lid:id$ >> ])
                ibod
            in
            let e =
              match el with
              [ [e] -> e
              | _ -> <:expr< ($list:el$) >> ]
            in
            <:expr< fun [ $p$ -> $e$ | _ -> invalid_arg "scatter" ] >>
          in
          (* Distinct index shapes used by the array inputs.  Payload
             strings are erased because [dim_fun_name] and [dim_fun_body]
             depend only on the constructors, so each accessor is defined
             once. *)
          let needed_fun =
            loop ibod where rec loop =
              fun
              [ [BParray _ _ pl :: l] ->
                  let gd =
                    List.map
                      (fun
                       [ BIident _ -> BIident ""
                       | BIrange _ _ _ -> BIrange "" "" ""
                       | BIempty -> BIempty ])
                      pl
                  in
                  let r = loop l in
                  if List.mem gd r then r else [gd :: r]
              | [BPvar _ _ :: l] -> loop l
              | [] -> [] ]
          in
          (* [gather]'s body: allocate the result [z], then store each [v]
             of the [(coordinates, v)] pairs of [l] at
             [z.(x1)...(xn)], and return [z]. *)
          let gath_body =
            <:expr<
              let z = $make_gath_array _loc odef oinit$ in
              do {
                List.iter
                  (fun
                   [ ($gath_param _loc odef$, v) ->
                       $gath_acc _loc odef$ := v
                   | _ -> invalid_arg "gather" ])
                  l;
                z
              }
            >>
          in
          let e =
            <:expr<
               let scatter $scat_prm$ = $scat_body$ in
               let gather l = $gath_body$ in
               matrix_map ($odim$, $nwork$, $body$, scatter, gather)
            >>
          in
          (* Wrap the expansion in the accessor definitions it uses. *)
          List.fold_right
            (fun nf e ->
               let nfp = <:patt< $lid:dim_fun_name nf$ >> in
               let nfb = dim_fun_body _loc nf in
               <:expr< let $nfp$ = $nfb$ in $e$ >>)
            needed_fun e ] ]
  ;
  (* Comma-separated, possibly empty, list of declared inputs. *)
  in_def:
    [ [ l = LIST0 param_def SEP "," -> l ] ]
  ;
  (* Exactly one declared output. *)
  out_def:
    [ [ p = param_def -> p ] ]
  ;
  (* Optional ", init = e" initial cell value of the output; default 0. *)
  out_init:
    [ [ ","; LIDENT "init"; "="; e = expr -> e
      | -> <:expr< 0 >> ] ]
  ;
  (* "a [n1] [n2] ...": a name with its located identifier and its
     dimension sizes, kept as integer literal strings. *)
  param_def:
    [ [ id = lident; pl = LIST0 index -> (id, pl) ] ]
  ;
  index:
    [ [ "["; i = INT; "]" -> i ] ]
  ;
  (* "*w1 = n1; *w2 = n2; ...": at least one worker dimension. *)
  workers_def:
    [ [ wpl = LIST1 worker_param_def SEP ";" -> wpl ] ]
  ;
  worker_param_def:
    [ [ "*"; id = lident; "="; i = INT -> (id, i) ] ]
  ;
  body_in_def:
    [ [ l = LIST0 body_param_def SEP "," -> l ] ]
  ;
  body_out_def:
    [ [ p = body_param_def -> p ] ]
  ;
  (* "a[..][..]" array access or "*x" plain variable. *)
  body_param_def:
    [ [ (_loc1, id) = lident; pl = LIST0 body_index -> BParray _loc1 id pl
      | "*"; id = LIDENT -> BPvar _loc id ] ]
  ;
  (* "[]" whole dimension, "[*i]" single index, or "[*i+d1:*i+d2]" range
     relative to [i]; both ends of a range must use the same variable. *)
  body_index:
    [ [ "["; "]" -> BIempty
      | "["; "*"; id = LIDENT; "]" -> BIident id
      | "["; "*"; id = LIDENT; d1 = body_index_incr; ":"; "*";
        (_loc1, id1) = lident; d2 = body_index_incr; "]" ->
          if id1 <> id then
            Stdpp.raise_with_loc _loc1 (Failure ("'" ^ id ^ "' expected"))
          else BIrange id d1 d2 ] ]
  ;
  (* Optional "+n" / "-n" offset, as a literal string; "0" when absent. *)
  body_index_incr:
    [ [ "+"; d = INT -> d
      | "-"; d = INT -> "-" ^ d
      | -> "0" ] ]
  ;
  (* A lowercase identifier paired with its source location. *)
  lident:
    [ [ id = LIDENT -> (_loc, id) ] ]
  ;
END;
back to top