Commit 24aa0851 authored by Julien Girard-Satabin's avatar Julien Girard-Satabin
Browse files

more robust getter if the index has an additional dimension

parent 8e43a869
......@@ -29,13 +29,41 @@ module Tensor = struct
else create_ y (Dim (Array.init x ~f:(fun _ -> acc))) false
| [] -> {data=acc;shape=shape}
in create_ (List.rev shape) (Dim [||]) true
let unsqueeze ~sh1 ~sh2 =
let longest, shortest = match (List.length sh1) > (List.length sh2) with
| true -> sh1, sh2
| false -> sh2, sh1
in
(*find the index of the potential additional dimension*)
(* printf "%s" (show_shape longest); *)
(* printf "%s" (show_shape shortest); *)
let where_zero = match List.nth_exn longest 0 with
| 0 -> Some 0
| _ -> (match List.last_exn longest with
| 0 -> Some ((List.length longest)-1)
| _ -> None)
in match where_zero with
| Some idx ->(match List.sub longest ~pos:idx
~len:(List.length shortest) with
| [] -> None
| _ -> Some longest)
| None -> None
let get_idx t idx =
if List.length t.shape <> List.length idx then
failwith "error, index is too long for tensor shape"
else let rec get_idx_ tdata id = match tdata with
| Dim x -> get_idx_ x.(hd_id id) (tl_id id)
| Row x -> x.(List.hd_exn id)
in get_idx_ t.data idx
(* printf "\nLength of shape: %d" (List.length t.shape); *)
(* printf "\nshape of tensor: %s" (show_shape t.shape); *)
(* printf "\nidx: %s\n" (show_shape idx); *)
let true_idx = match List.length t.shape = List.length idx with
| true -> idx
| false -> (match unsqueeze ~sh1:idx ~sh2:t.shape with
| Some shape -> shape
| None -> failwith "error, index is too long for tensor shape")
in
(* printf "\ntrue idx: %s\n" (show_shape true_idx); *)
let rec get_idx_ tdata id = match tdata with
| Dim x -> get_idx_ x.(hd_id id) (tl_id id)
| Row x -> x.(List.hd_exn id)
in get_idx_ t.data true_idx
let set_idx t (id:shape) v =
if List.length t.shape <> List.length id then
failwith "error, index is too long for tensor shape"
......
......@@ -98,6 +98,13 @@ module Tensor : sig
exchanged. *)
val transpose_2d: t -> t
(** [unsqueeze sh1 sh2] returns the lowest common shape between
[sh1] and [sh2], and None if there is no common shape. A common shape
is when a shape of higher dimension has only 1 coordinates on non-shared
dimensions with the other. *)
val unsqueeze: sh1:shape -> sh2:shape -> shape option
end
(** {1 Modules for graph generation} *)
......
......@@ -206,7 +206,7 @@ let get_declarations_cfg (env:env_cfg) t =
let rec declare_aux a (d:vdata_cfg) flatenedl =
let rawd = d.data in
(* Printf.printf "declaring aux for variable %s\n" a; *)
(* Printf.printf "shape: %s\n" (stringify_int d.shape); *)
(* Printf.printf "shape in env: %s\n" (IR.Tensor.show_shape d.shape); *)
(* Printf.printf "scanning flatenedl\n"; *)
match flatenedl with
| [] -> ()
......@@ -1002,12 +1002,12 @@ let pp_graph_cfg g t =
| None -> true) (IR.vertex_list g)
in
let gCNode = List.rev gCNode_ and gInOut = List.rev gInOut_ in
printf "Length of gCNnode: %d \n%!" (List.length gCNode);
printf "Length of gInOut: %d \n%!" (List.length gInOut);
print_endline "gCNode:";
List.iter (fun x -> printf "%s, %d;" (t_name x) (x.IR.Vertex.id)) gCNode;
print_endline "\ngInOut:";
List.iter (fun x -> printf "%s " (t_name x)) gInOut;
(* printf "Length of gCNnode: %d \n%!" (List.length gCNode); *)
(* printf "Length of gInOut: %d \n%!" (List.length gInOut); *)
(* print_endline "gCNode:"; *)
(* List.iter (fun x -> printf "%s, %d;" (t_name x) (x.IR.Vertex.id)) gCNode; *)
(* print_endline "\ngInOut:"; *)
(* List.iter (fun x -> printf "%s " (t_name x)) gInOut; *)
List.iter (pp_IOnode_cfg env) gInOut;
(* pp_env_cfg env; *)
let cnodes = List.flatten (List.map (pp_cnode_cfg env t g) (gCNode)) in
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment