From 558f5d173dbaa5f390594e1759942cce745e4116 Mon Sep 17 00:00:00 2001
From: Michele Alberti <michele.alberti@cea.fr>
Date: Wed, 5 Oct 2022 18:42:55 +0200
Subject: [PATCH] [dataset][SAVer] Rework types and SAVer's call prover API.

---
 src/SAVer.ml        | 32 ++++++++++++++++++--------------
 src/SAVer.mli       |  6 ++----
 src/dataset.ml      | 29 ++++++++++++++---------------
 src/dataset.mli     | 12 +++++++++---
 src/language.mli    |  4 ++--
 src/verification.ml | 10 +++++++++-
 6 files changed, 54 insertions(+), 39 deletions(-)

diff --git a/src/SAVer.ml b/src/SAVer.ml
index 64958a1..b636aa9 100644
--- a/src/SAVer.ml
+++ b/src/SAVer.ml
@@ -75,11 +75,13 @@ let re_group_number_of_predicate = function
   | Robust _ -> 3
   | CondRobust _ -> 4
 
-let negative_prover_answer_of_predicate = function
-  | Dataset.Correct -> Call_provers.Invalid
+let negative_answer = function
+  | Dataset.Correct ->
+    (* SAVer is complete wrt correct predicate. *)
+    Call_provers.Invalid
   | Robust _ | CondRobust _ -> Call_provers.Unknown ""
 
-let build_answer predicate prover_result =
+let build_answer predicate_kind prover_result =
   match prover_result.Call_provers.pr_answer with
   | Call_provers.HighFailure -> (
     let pr_output = prover_result.pr_output in
@@ -89,13 +91,13 @@ let build_answer predicate prover_result =
         Int.of_string (Re__Core.Group.get re_group re_group_number_dataset_size)
       in
       let nb_proved =
-        let re_group_number = re_group_number_of_predicate predicate in
+        let re_group_number = re_group_number_of_predicate predicate_kind in
         Int.of_string (Re__Core.Group.get re_group re_group_number)
       in
       let prover_answer =
         if nb_total = nb_proved
         then Call_provers.Valid
-        else negative_prover_answer_of_predicate predicate
+        else negative_answer predicate_kind
       in
       { prover_answer; nb_total; nb_proved }
     | None -> failwith "Cannot interpret the output provided by SAVer")
@@ -104,15 +106,17 @@ let build_answer predicate prover_result =
        anything in SAVer's driver. *)
     assert false
 
-let call limit config env config_prover ~dataset task =
-  let dataset_task = Dataset.interpret env Language.lookup_loaded_svms task in
-  let svm = dataset_task.model.filename in
-  let eps =
-    match dataset_task.predicate with
-    | Dataset.Correct -> None
-    | Robust e | CondRobust e -> Some e
+let call_prover limit config config_prover predicate =
+  let command =
+    let svm = predicate.Dataset.model.Language.filename in
+    let dataset = predicate.dataset in
+    let eps =
+      match predicate.kind with
+      | Dataset.Correct -> None
+      | Robust e | CondRobust e -> Some e
+    in
+    build_command config_prover svm dataset eps
   in
-  let command = build_command config_prover svm dataset eps in
   let prover_call =
     let res_parser =
       {
@@ -130,4 +134,4 @@ let call limit config env config_prover ~dataset task =
       ~printing_info:Printer.default_printing_info (Buffer.create 10)
   in
   let prover_result = Call_provers.wait_on_call prover_call in
-  build_answer dataset_task.predicate prover_result
+  build_answer predicate.kind prover_result
diff --git a/src/SAVer.mli b/src/SAVer.mli
index 2119b52..489d223 100644
--- a/src/SAVer.mli
+++ b/src/SAVer.mli
@@ -28,11 +28,9 @@ type answer = {
   nb_proved : int;  (** Number of data points verifying the property. *)
 }
 
-val call :
+val call_prover :
   Call_provers.resource_limit ->
   Whyconf.main ->
-  Env.env ->
   Whyconf.config_prover ->
-  dataset:string ->
-  Task.task ->
+  (Language.svm_shape, string) Dataset.predicate ->
   answer
diff --git a/src/dataset.ml b/src/dataset.ml
index f0bd338..3443b41 100644
--- a/src/dataset.ml
+++ b/src/dataset.ml
@@ -24,16 +24,14 @@ open Why3
 open Base
 
 type eps = Constant.constant
-type predicate = Correct | Robust of eps | CondRobust of eps
-type 'a task = { model : 'a; predicate : predicate }
+type kind = Correct | Robust of eps | CondRobust of eps
+type ('a, 'b) predicate = { model : 'a; dataset : 'b; kind : kind }
 
 let string_of_eps eps = Fmt.str "%a" Constant.print_def eps
 
 let find_predicate_ls env p =
-  let dataset_th =
-    Pmodule.read_module env [ "caisar" ] "DataSetClassification"
-  in
-  Theory.ns_find_ls dataset_th.mod_theory.th_export [ p ]
+  let th = Pmodule.read_module env [ "caisar" ] "DataSetClassification" in
+  Theory.ns_find_ls th.mod_theory.th_export [ p ]
 
 let failwith_unsupported_term t =
   failwith (Fmt.str "Unsupported term in %a" Pretty.print_term t)
@@ -41,11 +39,15 @@ let failwith_unsupported_term t =
 let failwith_unsupported_ls ls =
   failwith (Fmt.str "Unsupported logic symbol %a" Pretty.print_ls ls)
 
-let interpret env lookup task =
+let interpret_predicate env ~on_model ~on_dataset task =
   let term = Task.task_goal_fmla task in
   match term.t_node with
-  | Term.Tapp (ls, { t_node = Tapp (ls_svm_apply, _); _ } :: _dataset :: tt) ->
-    let predicate =
+  | Term.Tapp
+      ( ls,
+        { t_node = Tapp (ls_svm_apply, _); _ } (* model *)
+        :: { t_node = Tapp (dataset, _); _ } (* dataset *)
+        :: tt ) ->
+    let kind =
       match tt with
       | [] ->
         let correct_predicate = find_predicate_ls env "correct" in
@@ -62,12 +64,9 @@ let interpret env lookup task =
         else failwith_unsupported_ls ls
       | _ -> failwith_unsupported_term term
     in
-    let model =
-      match lookup ls_svm_apply with
-      | Some t -> t
-      | None -> invalid_arg "No model found in task"
-    in
-    { model; predicate }
+    let dataset = on_dataset dataset in
+    let model = on_model ls_svm_apply in
+    { model; dataset; kind }
   | _ ->
     (* No other term node is supported. *)
     failwith_unsupported_term term
diff --git a/src/dataset.mli b/src/dataset.mli
index e80c775..c05d438 100644
--- a/src/dataset.mli
+++ b/src/dataset.mli
@@ -23,8 +23,14 @@
 open Why3
 
 type eps
-type predicate = private Correct | Robust of eps | CondRobust of eps
-type 'a task = private { model : 'a; predicate : predicate }
+type kind = private Correct | Robust of eps | CondRobust of eps
+type ('a, 'b) predicate = private { model : 'a; dataset : 'b; kind : kind }
 
 val string_of_eps : eps -> string
-val interpret : Env.env -> (Term.lsymbol -> 'a option) -> Task.task -> 'a task
+
+val interpret_predicate :
+  Env.env ->
+  on_model:(Term.lsymbol -> 'a) ->
+  on_dataset:(Term.lsymbol -> 'b) ->
+  Task.task ->
+  ('a, 'b) predicate
diff --git a/src/language.mli b/src/language.mli
index cd3329a..4d0a1ca 100644
--- a/src/language.mli
+++ b/src/language.mli
@@ -32,10 +32,10 @@ type nn_shape = {
 type svm_shape = { nb_inputs : int; nb_classes : int; filename : string }
 
 val lookup_loaded_nets : Term.lsymbol -> nn_shape option
-(** @return the shape of a nnet Why3 representation. *)
+(** @return the shape of a NN given its Why3 representation. *)
 
 val lookup_loaded_svms : Term.lsymbol -> svm_shape option
-(** @return the shape of a svm Why3 representation. *)
+(** @return the shape of a SVM given its Why3 representation. *)
 
 val register_nnet_support : unit -> unit
 (** Register NNet parser. *)
diff --git a/src/verification.ml b/src/verification.ml
index 85e7b26..8b7fef1 100644
--- a/src/verification.ml
+++ b/src/verification.ml
@@ -74,7 +74,15 @@ let answer_saver limit config env config_prover dataset_csv task =
     | None -> invalid_arg "No dataset provided for SAVer"
     | Some filename -> filename
   in
-  let answer = SAVer.call limit config env config_prover ~dataset task in
+  let dataset_predicate =
+    let on_model ls =
+      let message = Fmt.str "No SVM model found as %a" Pretty.print_ls ls in
+      Option.value_exn ~message (Language.lookup_loaded_svms ls)
+    in
+    let on_dataset _ls = dataset in
+    Dataset.interpret_predicate env ~on_model ~on_dataset task
+  in
+  let answer = SAVer.call_prover limit config config_prover dataset_predicate in
   match answer.prover_answer with
   | Call_provers.Unknown "" ->
     let additional_info = Fmt.str "%d/%d" answer.nb_proved answer.nb_total in
-- 
GitLab