diff --git a/src/transformations/actual_net_apply.ml b/src/transformations/actual_net_apply.ml index 02c05df148d55c83471f1c091c2a5d222a48640f..c47c88e84774913b348ae07c7fc7b713001cad79 100644 --- a/src/transformations/actual_net_apply.ml +++ b/src/transformations/actual_net_apply.ml @@ -25,7 +25,10 @@ module G = Onnx.G (** TODO: - - 08-09-2022: write the proper semantic of equality *) + - 08-09-2022: write the proper semantic of Matmul, Add and Equality + - 09-09-2022: ensure that the graph traversal operations are properly done *) + +exception UnsupportedOperator of string let vars = Term.Hvs.create 100 let _lookup_vars = Term.Hvs.find_opt vars @@ -36,7 +39,7 @@ let create_var pfx id ty vars = Term.Hvs.add vars term id; term -let _create_fun_binop s ty = +let create_fun_binop s ty = let preid = Ident.id_fresh s in Term.create_fsymbol preid [ ty; ty ] ty @@ -45,43 +48,73 @@ let _declare_relu env = Theory.(ns_find_ls nn.mod_theory.th_export [ "relu" ]) (* To replace with proper symbols used in int.Int and ieee.Float64 *) -let sum_s ty = _create_fun_binop "+" ty +let sum_s ty = create_fun_binop "+" ty (* Declare equality between terms t1 and t2 *) -let _declare_eq t1 t2 = Term.t_equ t1 t2 +let declare_eq t1 t2 = Term.t_equ t1 t2 (* Declare a term defining equality between variables s1 and s2 *) -let _declare_eq_s s1 s2 = Term.t_equ (Term.t_var s1) (Term.t_var s2) +let declare_eq_s s1 s2 = Term.t_equ (Term.t_var s1) (Term.t_var s2) (* Term describing the sum of two variables v1 and v2 *) (* TODO: infer type of application*) -let _sum v1 v2 ty_vars = Term.t_app_infer (sum_s ty_vars) [ v1; v2 ] +let sum v1 v2 ty_vars = Term.t_app_infer (sum_s ty_vars) [ v1; v2 ] (* Let binding between terms t1 and t2 *) -let bind t1 t2 ty vars = +let _bind t1 t2 ty vars = let id = Term.Hvs.length vars and pfx = "EQ_" in Term.t_let_close (create_var pfx id ty vars) t1 t2 let _register_data_env _g _env = () -let id_on in_vars out_vars net_in_vars ty_vars = - match in_vars with - | x :: _ -> - List.init (List.length out_vars) ~f:(fun i -> - bind (Term.t_var @@ List.nth_exn out_vars i) (Term.t_var x) ty_vars vars) - | [] -> - List.init (List.length out_vars) ~f:(fun i -> - bind - (Term.t_var @@ List.nth_exn out_vars i) - (List.nth_exn net_in_vars 0) - ty_vars vars) - -let terms_of_nier g ty_inputs net_in_vars = - let ns = G.vertex_list g in +(* create terms defining the equality between two list of variables. + * Assuming equal size between in_vars and out_vars, + * resulting terms declare the equality between in_vars[i] + * and out_vars[i]*) +let id_on in_vars out_vars = + Stdio.printf "%d%!\n" (List.length in_vars); + Stdio.printf "%d%!\n" (List.length out_vars); + if List.length in_vars <> List.length out_vars + then + failwith + "Error, expecting same amount of variables before declaring equality" + else + let eq_terms = + List.foldi ~init:[] in_vars ~f:(fun i l in_var -> + declare_eq_s (List.nth_exn out_vars i) in_var :: l) + in + eq_terms + +(* create terms defining the element-wise addition between + * two list of variables and a data_node. Assuming equal size between + * in_vars and out_vars, resulting term declares out_vars[i] + * = in_vars + data[i] *) +let eltw_sum in_vars out_vars data_node ty_vars f = + let data = + match IR.Node.get_tensor data_node with Some t -> t | None -> assert false + in + let data_vars = + List.map + ~f:(fun v -> + let s_v = String.split ~on:'.' (f v) in + let v_int = List.nth_exn s_v 0 and v_frac = List.nth_exn s_v 1 in + Number.real_literal ~radix:10 ~neg:false ~int:v_int ~frac:v_frac + ~exp:None) + (IR.Tensor.flatten data) + in + match + (* TODO: morph real value into term*) + List.map3 out_vars in_vars data_vars ~f:(fun out_var in_var d -> + declare_eq (Term.t_var out_var) (sum in_var d ty_vars)) + with + | List.Or_unequal_lengths.Ok l -> l + | List.Or_unequal_lengths.Unequal_lengths -> + failwith "Error in element-wise sum: incoherent list length." + +let terms_of_nier g ty_inputs net_in_vars net_out_vars = let _, terms = - List.fold - ~init:([], [ [] ]) - ns + (* folding goes by increasing id order*) + G.fold_vertex (* We want to: * * get the input and output shape of each node * * declare input and output variables @@ -92,31 +125,53 @@ let terms_of_nier g ty_inputs net_in_vars = * variable (keep predecessor's output in the accumulator) * * *) - ~f:(fun (in_vars, l) n -> - let sh = - IR.Node.get_shape n - (*in let preds = match G.preds g n with*) - (* | x::y -> Some (x::y)*) - (* (*no preds: this is the*) (* * input node. handling*) (* * this - case separatly*)*) - (* | [] -> None*) - (* variable of the node. Assuming the shape here is - * the output shape of the node (it is not, but we - * have shape inferring function somewhere in - * ISAIEH. *) - in + (fun n (in_vars, l) -> + let open IR in + let sh = Node.get_shape n in + let node_id = n.id in + Stdio.printf "\n%s\n" (Node.show n Float.to_string); + (*in let preds = match G.preds g n with*) + (* | x::y -> Some (x::y)*) + (* (*no preds: this is the*) (* * input node. handling*) (* * this case + separatly*)*) + (* | [] -> None*) + (* variable of the node. Assuming the shape here is + * the output shape of the node (it is not, but we + * have shape inferring function somewhere in + * ISAIEH. *) let node_vs_out = List.init - ~f:(fun i -> create_var "toast_" i ty_inputs vars) - (List.length @@ List.concat (IR.Tensor.all_coords sh)) + ~f:(fun i -> + create_var ("out_" ^ Int.to_string node_id) i ty_inputs vars) + (List.length (Tensor.all_coords sh)) in let node_compute_terms = match IR.Node.get_op n with - (* for now, equality of all node - * variables with *) - | _ -> id_on in_vars node_vs_out net_in_vars + | Node.Matmul -> id_on in_vars node_vs_out + | Node.Add -> ( + match G.data_node_of n g with + | Some d_n -> eltw_sum in_vars node_vs_out d_n ty_inputs + | None -> failwith "Error, Add operator lacks a data node") + | Node.NO_OP -> + (* Multiple cases to consider: *) + (* If it is the input node, build variables + * using neuron input node *) + if Node.is_input_node n + then + id_on net_in_vars node_vs_out + (* If it is the output node, build variables + * using neuron output node *) + else id_on net_out_vars node_vs_out + | IR.Node.ReLu -> + (* TODO: implement using if then else*) + id_on in_vars node_vs_out + | op -> + raise + (UnsupportedOperator + (Fmt.str "Operator %s is not implemented." (IR.Node.str_op op))) in - (node_vs_out, node_compute_terms ty_inputs :: l)) + (node_vs_out, node_compute_terms :: l)) + g ([], []) in List.concat terms @@ -134,18 +189,22 @@ let actual_nn_flow _env = let input_vars = ref [ - Term.t_var (create_var "x" 0 ty_inputs vars); - Term.t_var (create_var "x" 1 ty_inputs vars); - Term.t_var (create_var "x" 2 ty_inputs vars); + create_var "x" 0 ty_inputs vars; + create_var "x" 1 ty_inputs vars; + create_var "x" 2 ty_inputs vars; ] in + let output_vars = + ref + [ create_var "y" 0 ty_inputs vars; create_var "y" 1 ty_inputs vars ] + in let g = let p = Onnx.parse nn_file in match p with | Error s -> Loc.errorm "%s" s | Ok (_model, nier) -> nier in - let cfg_terms = terms_of_nier g ty_inputs !input_vars in + let cfg_terms = terms_of_nier g ty_inputs !input_vars !output_vars in let cfg_term = Term.t_tuple cfg_terms in Pretty.print_term Fmt.stdout cfg_term; cfg_term)