From 3bc9d084cade3f979d561d68da12e9c2a284c3fd Mon Sep 17 00:00:00 2001
From: Aymeric Varasse <aymeric.varasse@cea.fr>
Date: Wed, 19 Apr 2023 14:21:31 +0200
Subject: [PATCH] [aimos] Create config file from improved predicate

---
 examples/acasxu/property_5_aimos.why |  2 +-
 src/AIMOS.ml                         | 65 +++++++++++++++++++++++++---
 src/AIMOS.mli                        |  1 -
 src/dataset.ml                       | 47 ++++++++++++++++++--
 src/dataset.mli                      | 10 ++++-
 src/verification.ml                  | 18 ++------
 stdlib/caisar.mlw                    |  6 ++-
 tests/aimos.t                        |  2 +-
 8 files changed, 122 insertions(+), 29 deletions(-)

diff --git a/examples/acasxu/property_5_aimos.why b/examples/acasxu/property_5_aimos.why
index 39e1305..937e063 100644
--- a/examples/acasxu/property_5_aimos.why
+++ b/examples/acasxu/property_5_aimos.why
@@ -4,5 +4,5 @@ theory ACASXU_P5
   use caisar.DatasetClassification
   use caisar.DatasetClassificationProps
 
-  goal G: meta_robust model dataset (1.0:t)
+  goal G: meta_robust model dataset (1.0:t) ("reluplex_rotation":string) (0:int) (0:int) (0:int)
 end
diff --git a/src/AIMOS.ml b/src/AIMOS.ml
index 50d2caf..1148e5a 100644
--- a/src/AIMOS.ml
+++ b/src/AIMOS.ml
@@ -25,16 +25,71 @@ open Base
 
 let aimos_file = Re__Core.(compile (str "%{aimos_file}"))
 
-let build_command config_prover aimos_filename =
+let write_config (inputs_path : string) (models_path : string)
+  (perturbation : string) (amplitude : Dataset.amplitude) =
+  let config_file = Caml.Filename.temp_file "aimos-" ".yml" in
+  let config_path = Fpath.v config_file in
+  let ampli = Dataset.string_of_amplitude amplitude in
+  let transformations : Yaml.value =
+    match ampli with
+    | None -> `A [ `O [ ("name", `String perturbation) ] ]
+    | Some a ->
+      `A [ `O [ ("name", `String perturbation); ("fn_range", `String a) ] ]
+  in
+  let options : Yaml.value =
+    `O
+      [
+        ("plot", `Bool false);
+        ("inputs_path", `String inputs_path);
+        ("transformations", transformations);
+        ("custom_t_path", `String "config/custom_transformations.py");
+      ]
+  in
+  let models : Yaml.value =
+    `A
+      [
+        `O
+          [
+            ( "defaults",
+              `O
+                [
+                  ("models_path", `String models_path);
+                  ("out_mode", `String "classif_min");
+                ] );
+          ];
+      ]
+  in
+  let full_config : Yaml.value =
+    `O [ ("options", options); ("models", models) ]
+  in
+  Yaml_unix.to_file_exn config_path full_config;
+  config_file
+
+let build_command config_prover
+  (predicate : (Language.nn_shape, string) Dataset.predicate) =
+  let dataset = predicate.dataset in
+  let inputs_path, models_path =
+    (Unix.realpath dataset, Unix.realpath predicate.model.filename)
+  in
+  let perturbation, (amplitude : Dataset.amplitude) =
+    match predicate.property with
+    | Dataset.MetaRobust (_, p, start, stop, step) ->
+      (p, { start = Some start; stop; step = Some step })
+    | _ -> failwith "Unsupported property"
+  in
+  let aimos_config =
+    write_config inputs_path models_path perturbation amplitude
+  in
   let command = Whyconf.get_complete_command ~with_steps:false config_prover in
-  Re__Core.replace_string aimos_file ~by:aimos_filename command
+  Re__Core.replace_string aimos_file ~by:aimos_config command
 
 let re_aimos_output = Re__Pcre.regexp "((,\\s)(\\d+\\.\\d+))"
 
 let build_answer predicate_kind prover_result =
   let threshold =
     match predicate_kind with
-    | Dataset.MetaRobust f -> Float.of_string (Dataset.string_of_threshold f)
+    | Dataset.MetaRobust (f, _, _, _, _) ->
+      Float.of_string (Dataset.string_of_threshold f)
     | _ -> failwith "Unsupported property"
   in
   match prover_result.Call_provers.pr_answer with
@@ -61,8 +116,8 @@ let build_answer predicate_kind prover_result =
     assert false
 
 let call_prover limit config config_prover
-  (predicate : (Language.nn_shape, string) Dataset.predicate) aimos_config =
-  let command = build_command config_prover aimos_config in
+  (predicate : (Language.nn_shape, string) Dataset.predicate) =
+  let command = build_command config_prover predicate in
   let prover_call =
     let res_parser =
       {
diff --git a/src/AIMOS.mli b/src/AIMOS.mli
index 724e072..411b0da 100644
--- a/src/AIMOS.mli
+++ b/src/AIMOS.mli
@@ -27,5 +27,4 @@ val call_prover :
   Whyconf.main ->
   Whyconf.config_prover ->
   (Language.nn_shape, string) Dataset.predicate ->
-  string ->
   Call_provers.prover_answer
diff --git a/src/dataset.ml b/src/dataset.ml
index 5221dd9..bbb7878 100644
--- a/src/dataset.ml
+++ b/src/dataset.ml
@@ -51,13 +51,36 @@ let term_of_eps env eps =
 
 type threshold = float [@@deriving yojson, show]
 
+type amplitude = {
+  start : int option;
+  stop : int;
+  step : int option;
+}
+[@@deriving yojson]
+
 let string_of_threshold threshold = Float.to_string threshold
 
+let string_of_amplitude amplitude =
+  let start = match amplitude.start with Some a -> a | None -> 0 in
+  if start = amplitude.stop
+  then None
+  else
+    let start_str =
+      Option.value_map amplitude.start ~default:"" ~f:Int.to_string
+    in
+    let step_str =
+      Option.value_map amplitude.step ~default:"" ~f:Int.to_string
+    in
+    Some
+      (Fmt.str "range(%s, %s, %s)" start_str
+         (Int.to_string amplitude.stop)
+         step_str)
+
 type property =
   | Correct
   | Robust of eps
   | CondRobust of eps
-  | MetaRobust of threshold
+  | MetaRobust of threshold * string * int * int * int
 [@@deriving yojson, show]
 
 type ('a, 'b) predicate = {
@@ -82,6 +105,7 @@ let failwith_unsupported_term t =
 let failwith_unsupported_ls ls =
   failwith (Fmt.str "Unsupported logic symbol %a" Pretty.print_ls ls)
 
+(* TODO: use amplitude type instead of 3 ints *)
 let interpret_predicate env ~on_model ~on_dataset task =
   let task = Trans.apply Introduction.introduce_premises task in
   let term = Task.task_goal_fmla task in
@@ -101,14 +125,29 @@ let interpret_predicate env ~on_model ~on_dataset task =
       | [ { t_node = Tconst (Constant.ConstReal e); _ } ] ->
         let robust_predicate = find_predicate_ls env "robust" in
         let cond_robust_predicate = find_predicate_ls env "cond_robust" in
-        let meta_robust_predicate = find_predicate_ls env "meta_robust" in
         let f = float_of_real_constant e in
         if Term.ls_equal ls robust_predicate
         then Robust f
         else if Term.ls_equal ls cond_robust_predicate
         then CondRobust f
-        else if Term.ls_equal ls meta_robust_predicate
-        then MetaRobust f
+        else failwith_unsupported_ls ls
+      | [
+       { t_node = Tconst (Constant.ConstReal e); _ };
+       { t_node = Tconst (Constant.ConstStr p); _ };
+       { t_node = Tconst (Constant.ConstInt start); _ };
+       { t_node = Tconst (Constant.ConstInt stop); _ };
+       { t_node = Tconst (Constant.ConstInt step); _ };
+      ] ->
+        let meta_robust_predicate = find_predicate_ls env "meta_robust" in
+        let f = float_of_real_constant e in
+        if Term.ls_equal ls meta_robust_predicate
+        then
+          MetaRobust
+            ( f,
+              p,
+              Number.to_small_integer start,
+              Number.to_small_integer stop,
+              Number.to_small_integer step )
         else failwith_unsupported_ls ls
       | _ -> failwith_unsupported_term term
     in
diff --git a/src/dataset.mli b/src/dataset.mli
index 13efaec..455374b 100644
--- a/src/dataset.mli
+++ b/src/dataset.mli
@@ -30,13 +30,21 @@ val term_of_eps : Env.env -> eps -> Term.term
 
 type threshold [@@deriving yojson, show]
 
+type amplitude = {
+  start : int option;
+  stop : int;
+  step : int option;
+}
+[@@deriving yojson]
+
 val string_of_threshold : threshold -> string
+val string_of_amplitude : amplitude -> string option
 
 type property = private
   | Correct
   | Robust of eps
   | CondRobust of eps
-  | MetaRobust of threshold
+  | MetaRobust of threshold * string * int * int * int
 [@@deriving yojson, show]
 
 type ('model, 'dataset) predicate = private {
diff --git a/src/verification.ml b/src/verification.ml
index 1a7162e..2d806e9 100644
--- a/src/verification.ml
+++ b/src/verification.ml
@@ -123,7 +123,7 @@ let answer_saver limit config env config_prover dataset task =
     let additional_info = Fmt.str "(%d/%d)" answer.nb_proved answer.nb_total in
     (prover_answer, Generic (Some additional_info))
 
-let answer_aimos limit config env config_prover dataset task aimos_config =
+let answer_aimos limit config env config_prover dataset task =
   let predicate =
     let on_model ls =
       Option.value_or_thunk
@@ -145,17 +145,7 @@ let answer_aimos limit config env config_prover dataset task aimos_config =
     in
     Dataset.interpret_predicate env ~on_model ~on_dataset task
   in
-  let aimos_filename =
-    match aimos_config with
-    | None ->
-      List.map Dirs.Sites.config ~f:(fun dir ->
-        Caml.Filename.concat dir "aimos_config.yml")
-      |> List.find_exn ~f:Caml.Sys.file_exists
-    | Some s -> s
-  in
-  let answer =
-    AIMOS.call_prover limit config config_prover predicate aimos_filename
-  in
+  let answer = AIMOS.call_prover limit config config_prover predicate in
   let additional_info = Generic None in
   (answer, additional_info)
 
@@ -262,9 +252,7 @@ let call_prover ?dataset ~limit config env prover config_prover driver task =
   let prover_answer, additional_info =
     match prover with
     | Prover.Saver -> answer_saver limit config env config_prover dataset task
-    | Aimos ->
-      (* TODO: add real config file *)
-      answer_aimos limit config env config_prover dataset task None
+    | Aimos -> answer_aimos limit config env config_prover dataset task
     | (Marabou | Pyrat | Nnenum) when Option.is_some dataset ->
       let dataset = Unix.realpath (Option.value_exn dataset) in
       answer_dataset limit config env prover config_prover driver dataset task
diff --git a/stdlib/caisar.mlw b/stdlib/caisar.mlw
index 13583bb..47d6820 100644
--- a/stdlib/caisar.mlw
+++ b/stdlib/caisar.mlw
@@ -24,11 +24,14 @@ theory DatasetClassification
   use ieee_float.Float64
   use int.Int
   use array.Array
+  use option.Option
 
   type features = array t
   type label_ = int
   type record = (features, label_)
   type dataset = array record
+  type ampli = { start: option int; stop: int; step: option int; }
+  type amplitude = option ampli
 
   constant dataset: dataset
 
@@ -69,7 +72,8 @@ theory DatasetClassificationProps
   predicate cond_robust (m: model) (d: dataset) (eps: t) =
     correct m d /\ robust m d eps
 
-  predicate meta_robust (m: model) (d: dataset) (threshold: t)
+  predicate meta_robust (m: model) (d: dataset) (threshold: t) (perturbation: string) (start: int) (stop: int) (step: int)
+  (* predicate meta_robust (m: model) (d: dataset) (threshold: t) (perturbation: string) (ampli: amplitude) *)
 end
 
 theory NN
diff --git a/tests/aimos.t b/tests/aimos.t
index db58028..99db4d5 100644
--- a/tests/aimos.t
+++ b/tests/aimos.t
@@ -19,7 +19,7 @@ Test verify
   >   use caisar.DatasetClassification
   >   use caisar.DatasetClassificationProps
   > 
-  >   goal G: meta_robust model dataset (1.0:t)
+  >   goal G: meta_robust model dataset (1.0:t) ("reluplex_rotation":string) (0:int) (0:int) (0:int)
   > end
   > EOF
   [caisar] Goal G: Valid
-- 
GitLab