From 0ed4d922ab817423ca29b4c31115f5d2399790c2 Mon Sep 17 00:00:00 2001
From: Julien Girard <julien.girard2@cea.fr>
Date: Fri, 23 Sep 2022 11:35:08 +0200
Subject: [PATCH] [TRANS] Better naming for variables, added some
 documentation.

---
 src/transformations/actual_net_apply.ml | 226 +++++++++++++-----------
 1 file changed, 119 insertions(+), 107 deletions(-)

diff --git a/src/transformations/actual_net_apply.ml b/src/transformations/actual_net_apply.ml
index 5eb9e9e..c6febc8 100644
--- a/src/transformations/actual_net_apply.ml
+++ b/src/transformations/actual_net_apply.ml
@@ -14,7 +14,7 @@ exception UnsupportedOperator of string
 let vars = Term.Hvs.create 100
 let _lookup_vars = Term.Hvs.find_opt vars
 
-(* import the proper theory according to the types of
+(* Import the proper theory according to the types of
  * variables *)
 let theory_for ty_vars env =
   if Ty.ty_equal ty_vars Ty.ty_real
@@ -27,10 +27,6 @@ let create_var pfx id ty vars =
   Term.Hvs.add vars vsymbol id;
   vsymbol
 
-let _create_fun_binop s ty =
-  let preid = Ident.id_fresh s in
-  Term.create_fsymbol preid [ ty; ty ] ty
-
 (* Conversion from tensor float data to a constant real term *)
 let term_of_data d ty =
   let d_s = Printf.sprintf "%h" d in
@@ -54,17 +50,15 @@ let term_of_data d ty =
       let is_neg = Float.( <= ) d 0. in
       Number.real_literal ~radix:16 ~neg:is_neg ~int ~frac ~exp
   in
+  (* TODO: safely check whether the real value can be
+   * expressed as a float. Fail if not, with an informative
+   * error message (e.g.: invalid float representation, maybe
+   * try this one with rounding?) *)
   Term.t_const
     (Constant.real_const ~pow2:rl.rl_real.rv_pow2 ~pow5:rl.rl_real.rv_pow5
        rl.rl_real.rv_sig)
     ty
 
-(* Declare equality between terms t1 and t2 *)
-let _declare_eq t1 t2 = Term.t_equ t1 t2
-
-(* Declare a term defining equality between variables s1 and s2 *)
-let _declare_eq_s s1 s2 = Term.t_equ (Term.t_var s1) (Term.t_var s2)
-
 (* Term describing the sum of two variables v1 and v2 with
  * their proper types *)
 let sum v1 v2 ty_vars env =
@@ -89,12 +83,13 @@ let mul v1 v2 ty_vars env =
 
 (* Bind variable v to term t in expression e *)
 let bind v ~t ~e = Term.t_let_close v t e
-let _register_data_env _g _env = ()
 
-(* Create terms defining the equality between two list of variables.
- * Assuming equal size between in_vars and out_vars,
- * resulting terms declare the equality between in_vars[i]
- * and out_vars[i]*)
+(* [id_on ~in_vars ~out_vars expr] creates a binding
+ * between two lists of variables [in_vars] and [out_vars].
+ * The first variables of both lists are bound in expression
+ * [expr]; each subsequent binding is added on top of the
+ * resulting expression.
+ *)
 let id_on ~in_vars ~out_vars expr =
   if List.length in_vars <> List.length out_vars
   then
@@ -107,10 +102,13 @@ let id_on ~in_vars ~out_vars expr =
     in
     eq_term
 
-(* Create terms defining the ReLU activation function
- * application between two list of variables.
- * Assuming equal size between in_vars and out_vars,
- * resulting terms defines in_vars[i] = relu(out_vars[i])
+(* [relu in_vars out_vars env expr] creates terms defining
+ * the application of the ReLU activation function
+ * between two lists of variables [in_vars] and
+ * [out_vars] on expression [expr].
+ * The first variables of both lists are bound in expression
+ * [expr]; each subsequent binding is added on top of the
+ * resulting expression.
  * *)
 let relu ~in_vars ~out_vars env expr =
   if List.length in_vars <> List.length out_vars
@@ -125,14 +123,21 @@ let relu ~in_vars ~out_vars env expr =
     let eq_term =
       List.foldi ~init:expr in_vars ~f:(fun i e in_var ->
         let relu_on = Term.t_app_infer relu_s [ Term.t_var in_var ] in
-        bind (List.nth_exn in_vars i) ~t:relu_on ~e)
+        bind (List.nth_exn out_vars i) ~t:relu_on ~e)
     in
     eq_term
 
-(* Create terms defining the element-wise addition between
- * two list of variables and a data_node. Assuming equal size between
- * in_vars and out_vars, resulting term declares out_vars[i]
- * = in_vars + data[i] *)
+(* [eltw_sum in_vars out_vars data_node ty_vars env expr]
+ * creates terms defining the element-wise addition between
+ * two lists of variables [in_vars] and [out_vars],
+ * a [data_node] holding the numerical
+ * value to add, and an expression [expr].
+ * The first variables of both lists are bound in expression
+ * [expr]; each subsequent binding is added on top of the
+ * resulting expression.
+ * Assuming equal size between
+ * in_vars and out_vars, the resulting term declares
+ * let out_vars[i] = in_vars[i] + data[i] in ... *)
 let eltw_sum ~in_vars ~out_vars data_node ty_vars env expr =
   let data =
     match IR.Node.get_tensor data_node with Some t -> t | None -> assert false
@@ -156,12 +161,20 @@ let eltw_sum ~in_vars ~out_vars data_node ty_vars env expr =
  * C = AB
  * C[i,j] = sum_k A[i;k] * B[k;j]
  *
- * Create terms defining the matrix multiplication between
- * two list of variables a_vars, c_vars and a data_node.
- * b_vars is built using datas stored in data_node.
+ * [matmul in_vars out_vars data_node in_shape out_shape ty_vars env expr]
+ * creates terms defining the matrix multiplication between
+ * two lists of variables in_vars, out_vars and a data_node.
+ * This function relies on the following assumptions:
+ * * in_vars represents the cells of matrix A (a_vars)
+ * * data stored in data_node is used to build the cells of matrix B (b_vars)
+ * * out_vars represents the cells of matrix C (c_vars)
  * a_vars are the input variables
+ * b_vars the data variables
  * c_vars the output variables
- * c_vars[i,j] = sum_k a_vars[i,k] * b_vars[k,j] *)
+ * c_vars[i,j] = sum_k a_vars[i,k] * b_vars[k,j]
+ * The first variables of both lists are bound in expression
+ * expr; each subsequent binding is added on top of the
+ * resulting expression. *)
 let matmul ~in_vars ~out_vars data_node ~in_shape ~out_shape ty_vars env expr =
   let data =
     match IR.Node.get_tensor data_node with Some t -> t | None -> assert false
@@ -176,8 +189,8 @@ let matmul ~in_vars ~out_vars data_node ~in_shape ~out_shape ty_vars env expr =
     | [] -> (i, j, t)
     | x :: y ->
       (* c[i,j] = sum_k a[i,k]*b[k,j]*)
-      (*a_var_range: all line of a *)
-      (*b_var_range: all column of b *)
+      (* a_var_range: all line of a *)
+      (* b_var_range: all column of b *)
       (* TODO: be sure that the common dimension is indeed
        * b_shape[0] *)
       let k_dim = Array.get b_shape 0 in
@@ -219,88 +232,86 @@ let matmul ~in_vars ~out_vars data_node ~in_shape ~out_shape ty_vars env expr =
   in
   terms
 
-let terms_of_nier g ty_inputs env ~output_vars ~input_vars =
+let terms_of_nier g ty_inputs env ~net_output_vars ~net_input_vars =
   IR.out_cfg_graph g;
+  (* The current NIER generation builds the data nodes after
+   * the output variables, so we drop them here since we will
+   * access them later anyway. *)
   let vs =
     let l = G.vertex_list g in
     List.drop_while ~f:(fun n -> not (IR.Node.is_output_node n)) l
   in
   let _, expr =
-    (* folding goes by decreasing id order, backward.
-     * Only accumulate new terms while going through the
-     * control flow; once the input node has been reached,
-     * only return the terms. *)
+    (* Folding goes by decreasing id order, backward.*)
     List.fold vs
       ~init:
-        ( (output_vars, IR.Node.get_shape @@ List.nth_exn vs 0, false),
-          Term.t_tuple @@ List.map ~f:Term.t_var output_vars )
-      ~f:(fun ((out_vars, out_shape, is_finished), expr) n ->
-        if is_finished
-        then (([], [||], true), expr)
-        else
-          let open IR in
-          let in_shape =
-            match Node.get_shape n with
-            | [||] -> G.infer_shape g n out_shape ~on_backward:true
-            | a -> a
-          in
-          let node_id = n.id in
-          let node_vs_in =
-            List.init
-              ~f:(fun i ->
-                create_var ("in_id_" ^ Int.to_string node_id) i ty_inputs vars)
-              (List.length (Tensor.all_coords in_shape))
-          in
-          let node_compute_term =
-            (* TODO: axiomatize the resulting term using
-             * let d = Decl.create_prop_decl Paxiom ps t *)
-            match IR.Node.get_op n with
-            | Node.Matmul -> (
-              match G.data_node_of n g with
-              | Some d_n ->
-                matmul ~in_vars:node_vs_in ~out_vars d_n ~in_shape ~out_shape
-                  ty_inputs env expr
-              | None -> failwith "Error, Matmul operator lacks a data node")
-            | Node.Add -> (
-              match G.data_node_of n g with
-              | Some d_n ->
-                eltw_sum ~out_vars ~in_vars:node_vs_in d_n ty_inputs env expr
-              | None -> failwith "Error, Add operator lacks a data node")
-            | Node.NO_OP ->
-              (* If it is the input node, build variables
-               * using neuron input node *)
-              if Node.is_input_node n
-              then
-                id_on ~out_vars:input_vars ~in_vars:node_vs_in expr
-                (* If it is the output node, the resulting
-                 * term is the tuple of the output variables
-                 * of the net;
-                 * backpropagate those to the previous layer*)
-              else if Node.is_output_node n
-              then
-                id_on ~out_vars:output_vars ~in_vars:node_vs_in
-                  (Term.t_tuple @@ List.map ~f:Term.t_var output_vars)
-              else expr
-            | IR.Node.ReLu -> relu ~out_vars ~in_vars:node_vs_in env expr
-            | op ->
-              raise
-                (UnsupportedOperator
-                   (Fmt.str
-                      "Operator %s is not implemented for actual_net_apply."
-                      (IR.Node.str_op op)))
-          in
-          ((node_vs_in, out_shape, false), node_compute_term))
+        ( (net_output_vars, IR.Node.get_shape @@ List.nth_exn vs 0),
+          Term.t_tuple @@ List.map ~f:Term.t_var net_output_vars )
+      ~f:(fun ((v_out_vars, out_shape), expr) v ->
+        let open IR in
+        let in_shape =
+          match Node.get_shape v with
+          | [||] -> G.infer_shape g v out_shape ~on_backward:true
+          | a -> a
+        in
+        let v_id = v.id in
+        let v_in_vars =
+          List.init
+            ~f:(fun i ->
+              create_var ("n_id_" ^ Int.to_string v_id ^ "_") i ty_inputs vars)
+            (List.length (Tensor.all_coords in_shape))
+        in
+        let v_term =
+          (* TODO: axiomatize the resulting term using
+           * let d = Decl.create_prop_decl Paxiom ps t *)
+          match IR.Node.get_op v with
+          | Node.Matmul -> (
+            match G.data_node_of v g with
+            | Some d_n ->
+              matmul ~in_vars:v_in_vars ~out_vars:v_out_vars d_n ~in_shape
+                ~out_shape ty_inputs env expr
+            | None -> failwith "Error, Matmul operator lacks a data node")
+          | Node.Add -> (
+            match G.data_node_of v g with
+            | Some d_n ->
+              eltw_sum ~out_vars:v_out_vars ~in_vars:v_in_vars d_n ty_inputs env
+                expr
+            | None -> failwith "Error, Add operator lacks a data node")
+          | Node.NO_OP ->
+            (* If it is the input vertex, bind the neural network
+             * input variables to the vertex's output variables. *)
+            if Node.is_input_node v
+            then
+              id_on ~out_vars:v_out_vars ~in_vars:net_input_vars expr
+              (* If it is the output vertex, the resulting
+               * term is the tuple of the output variables
+               * of the net;
+               * backpropagate those to the previous layer. *)
+            else if Node.is_output_node v
+            then
+              id_on ~out_vars:net_output_vars ~in_vars:v_in_vars
+                (Term.t_tuple @@ List.map ~f:Term.t_var net_output_vars)
+            else expr
+          | IR.Node.ReLu ->
+            relu ~out_vars:v_out_vars ~in_vars:v_in_vars env expr
+          | op ->
+            raise
+              (UnsupportedOperator
+                 (Fmt.str "Operator %s is not implemented for actual_net_apply."
+                    (IR.Node.str_op op)))
+        in
+        ((v_in_vars, out_shape), v_term))
   in
   expr
 
 (* Create logic symbols for input variables and replace
  * nnet_apply by control flow terms. *)
 let actual_nn_flow env =
-  let rec aux (term : Term.term) =
+  let rec substitute_net_apply (term : Term.term) =
     match term.t_node with
-    | Term.Tapp (ls, _args) -> (
+    | Term.Tapp (ls, args) -> (
       match Language.lookup_loaded_nets ls with
-      | None -> Term.t_map aux term
+      | None -> Term.t_map substitute_net_apply term
       | Some nn ->
         let g =
           match nn.nier with
@@ -308,24 +319,25 @@ let actual_nn_flow env =
           | None -> failwith "Error, call this transform only on an ONNX NN."
         in
         let ty_inputs = nn.ty_data in
+        let net_input_vars =
+          List.map args ~f:(fun x ->
+            (* net_apply should always be applied to a
+             * non-empty list of arguments *)
+            match x.Term.t_node with Tvar ts -> ts | _ -> assert false)
+        in
         let cfg_term =
-          terms_of_nier g ty_inputs env
-            ~output_vars:
+          terms_of_nier g ty_inputs env (* TODO: how to get those? *)
+            ~net_output_vars:
               [
                 create_var "y" 1 ty_inputs vars; create_var "y" 2 ty_inputs vars;
               ]
-            ~input_vars:
-              [
-                create_var "x" 1 ty_inputs vars;
-                create_var "x" 2 ty_inputs vars;
-                create_var "x" 3 ty_inputs vars;
-              ]
+            ~net_input_vars
         in
         Stdio.printf "\nObtained term: \n%!";
         Pretty.print_term Fmt.stdout cfg_term;
         Stdio.printf "\n%!";
         cfg_term)
-    | _ -> Term.t_map aux term
+    | _ -> Term.t_map substitute_net_apply term
   in
   Trans.fold
     (fun task_hd task ->
@@ -335,7 +347,7 @@ let actual_nn_flow env =
         let decl =
           Decl.decl_map
             (fun term ->
-              let term = aux term in
+              let term = substitute_net_apply term in
               term)
             decl
         in
-- 
GitLab