From 9718c499e8bfe4e95c36f3c50d032a7fa1cb6091 Mon Sep 17 00:00:00 2001
From: Michele Alberti <michele.alberti@cea.fr>
Date: Thu, 15 Sep 2022 17:21:06 +0200
Subject: [PATCH] [SAVer] Extend SAVer support for correct and conditional
 robust predicates.

---
 lib/ovo/ovo.ml      |  69 +++++++++---------
 src/saver.ml        | 168 ++++++++++++++++++++++++++++++++++++++++++++
 src/saver.mli       |  38 ++++++++++
 src/verification.ml | 104 ++++-----------------------
 stdlib/caisar.mlw   |  32 +++++++--
 tests/bin/saver     |   2 +-
 tests/simple_ovo.t  |   8 ++-
 7 files changed, 287 insertions(+), 134 deletions(-)
 create mode 100644 src/saver.ml
 create mode 100644 src/saver.mli

diff --git a/lib/ovo/ovo.ml b/lib/ovo/ovo.ml
index fcc63082..f05331d3 100644
--- a/lib/ovo/ovo.ml
+++ b/lib/ovo/ovo.ml
@@ -40,51 +40,48 @@ let handle_ovo_line ~f in_channel =
     ~f:(fun s -> try Some (f (String.strip s)) with _ -> None)
     (Csv.next in_channel)
 
-(* Skip the header part, ie comments, of the OVO format. *)
-let skip_ovo_header filename in_channel =
-  let exception End_of_header in
-  let pos_in = ref (Stdlib.pos_in in_channel) in
-  try
-    while true do
-      let line = Stdlib.input_line in_channel in
-      if not (Str.string_match (Str.regexp "//") line 0)
-      then raise End_of_header
-      else pos_in := Stdlib.pos_in in_channel
-    done;
-    assert false
-  with
-  | End_of_header ->
-    (* At this point we have read one line past the header part: seek back. *)
-    Stdlib.seek_in in_channel !pos_in;
-    Ok ()
-  | End_of_file ->
-    Error (Format.sprintf "OVO model not found in file `%s'." filename)
+(* Handle ovo first line: either 'ovo' or 'ovo x y' with 'x' and 'y' positive
+   integer numbers. *)
+let handle_ovo_first_line in_channel =
+  let ovo_format_error_on_first_line = ovo_format_error "first line" in
+  match Csv.next in_channel with
+  | [ first_line ] -> (
+    let in_channel = Csv.of_string ~separator:' ' first_line in
+    match Csv.next in_channel with
+    | [ "ovo"; n_is; n_os ] ->
+      let n_is = Int.of_string (String.strip n_is) in
+      let n_os = Int.of_string (String.strip n_os) in
+      Ok (Some (n_is, n_os))
+    | [ "ovo" ] -> Ok None
+    | _ -> ovo_format_error_on_first_line
+    | exception End_of_file -> ovo_format_error_on_first_line)
+  | _ -> ovo_format_error_on_first_line
+  | exception End_of_file -> ovo_format_error_on_first_line
 
 (* Retrieve inputs and outputs size. *)
-let handle_ovo_basic_info in_channel =
+let handle_ovo_basic_info ~descr in_channel =
   match handle_ovo_line ~f:Int.of_string in_channel with
   | [ dim ] -> Ok dim
-  | _ -> ovo_format_error "first"
-  | exception End_of_file -> ovo_format_error "first"
-
-(* Skip unused flag. *)
-let handle_ovo_unused_flag in_channel =
-  try
-    let _ = Csv.next in_channel in
-    Ok ()
-  with End_of_file -> ovo_format_error "second"
+  | _ -> ovo_format_error descr
+  | exception End_of_file -> ovo_format_error descr
 
 (* Retrieves [filename] OVO model metadata and weights wrt OVO format
    specification, which is described here:
-   https://github.com/abstract-machine-learning/saver#classifier-format. *)
-let parse_in_channel filename in_channel =
+   https://github.com/abstract-machine-learning/saver#classifier-format. In
+   practice, the first line may specify the input/output size values or not. In
+   the latter case, we search for input/output size values in the second/third
+   lines respectively. *)
+let parse_in_channel in_channel =
   let open Result in
   try
-    skip_ovo_header filename in_channel >>= fun () ->
     let in_channel = Csv.of_channel in_channel in
-    handle_ovo_unused_flag in_channel >>= fun () ->
-    handle_ovo_basic_info in_channel >>= fun n_is ->
-    handle_ovo_basic_info in_channel >>= fun n_os ->
+    (handle_ovo_first_line in_channel >>= function
+     | Some (n_is, n_os) -> Ok (n_is, n_os)
+     | None ->
+       handle_ovo_basic_info ~descr:"input size" in_channel >>= fun n_is ->
+       handle_ovo_basic_info ~descr:"output size" in_channel >>= fun n_os ->
+       Ok (n_is, n_os))
+    >>= fun (n_is, n_os) ->
     Csv.close_in in_channel;
     Ok { n_inputs = n_is; n_outputs = n_os }
   with
@@ -96,4 +93,4 @@ let parse filename =
   let in_channel = Stdlib.open_in filename in
   Fun.protect
     ~finally:(fun () -> Stdlib.close_in in_channel)
-    (fun () -> parse_in_channel filename in_channel)
+    (fun () -> parse_in_channel in_channel)
diff --git a/src/saver.ml b/src/saver.ml
new file mode 100644
index 00000000..2c563a46
--- /dev/null
+++ b/src/saver.ml
@@ -0,0 +1,168 @@
+(**************************************************************************)
+(*                                                                        *)
+(*  This file is part of CAISAR.                                          *)
+(*                                                                        *)
+(*  Copyright (C) 2022                                                    *)
+(*    CEA (Commissariat à l'énergie atomique et aux énergies              *)
+(*         alternatives)                                                  *)
+(*                                                                        *)
+(*  You can redistribute it and/or modify it under the terms of the GNU   *)
+(*  Lesser General Public License as published by the Free Software       *)
+(*  Foundation, version 2.1.                                              *)
+(*                                                                        *)
+(*  It is distributed in the hope that it will be useful,                 *)
+(*  but WITHOUT ANY WARRANTY; without even the implied warranty of        *)
+(*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the          *)
+(*  GNU Lesser General Public License for more details.                   *)
+(*                                                                        *)
+(*  See the GNU Lesser General Public License version 2.1                 *)
+(*  for more details (enclosed in the file licenses/LGPLv2.1).            *)
+(*                                                                        *)
+(**************************************************************************)
+
+open Why3
+open Base
+
+type predicate_kind = Correct | Robust | CondRobust
+
+type interpreted_task = {
+  svm_filename : string;
+  predicate_kind : predicate_kind;
+  eps : string;
+}
+
+let find_predicate_ls env p =
+  let dataset_th = Pmodule.read_module env [ "caisar" ] "DataSet" in
+  Theory.ns_find_ls dataset_th.mod_theory.th_export [ p ]
+
+let interpret_predicate env ls =
+  let correct_predicate = find_predicate_ls env "correct" in
+  let robust_predicate = find_predicate_ls env "robust" in
+  let cond_robust_predicate = find_predicate_ls env "cond_robust" in
+  if Term.ls_equal ls correct_predicate
+  then Correct
+  else if Term.ls_equal ls robust_predicate
+  then Robust
+  else if Term.ls_equal ls cond_robust_predicate
+  then CondRobust
+  else failwith (Fmt.str "Unsupported by SAVer: %a" Pretty.print_ls ls)
+
+let interpret_task env task =
+  let goal = Task.task_goal_fmla task in
+  match goal.t_node with
+  | Term.Tapp
+      ( ls,
+        [
+          { t_node = Tapp (ls_svm_apply, _); _ };
+          _dataset;
+          { t_node = Tconst e; _ };
+        ] ) ->
+    let predicate_kind = interpret_predicate env ls in
+    let eps = Fmt.str "%a" Constant.print_def e in
+    let svm_filename =
+      match Language.lookup_loaded_svms ls_svm_apply with
+      | Some t -> Unix.realpath t.filename
+      | None -> invalid_arg "No SVM model found in task"
+    in
+    { svm_filename; predicate_kind; eps }
+  | _ ->
+    (* No other term is supported. *)
+    failwith "Unsupported term by SAVer"
+
+let svm = Re__Core.(compile (str "%{svm}"))
+let dataset = Re__Core.(compile (str "%{dataset}"))
+let epsilon = Re__Core.(compile (str "%{epsilon}"))
+let abstraction = Re__Core.(compile (str "%{abstraction}"))
+let distance = Re__Core.(compile (str "%{distance}"))
+
+let build_command config_prover dataset_filename interpreted_task =
+  let command = Whyconf.get_complete_command ~with_steps:false config_prover in
+  let params =
+    [
+      (svm, interpreted_task.svm_filename);
+      (dataset, Unix.realpath dataset_filename);
+      (epsilon, interpreted_task.eps);
+      (distance, "l_inf");
+      (abstraction, "hybrid");
+    ]
+  in
+  List.fold params ~init:command ~f:(fun cmd (param, by) ->
+    Re__Core.replace_string param ~by cmd)
+
+type answer = {
+  prover_answer : Call_provers.prover_answer;
+  nb_total : int;
+  nb_proved : int;
+}
+
+(* SAVer output is made of 7 columns separated by (multiple) space(s) of the
+   following form:
+
+   [SUMMARY] integer float float integer integer integer
+
+   We are interested in recovering the integers. The following regexp matches
+   groups of one or more digits by means of '(\\d+)'. The latter is used 4 times
+   for the 4 columns of integers we are interested in. The 1st group reports the
+   number of total data points in the dataset, 2nd group reports the number of
+   correct data points, 3rd group reports the number of robust data points, 4th
+   group reports the number of conditionally robust (ie, correct and robust)
+   data points. *)
+let re_saver_output =
+  Re__Pcre.regexp
+    "\\[SUMMARY\\]\\s*(\\d+)\\s*[0-9.]+\\s*[0-9.]+\\s*(\\d)+\\s*(\\d+)\\s*(\\d+)"
+
+(* The dataset size is the first group of digits matched by
+   [re_saver_output]. *)
+let re_group_number_dataset_size = 1
+
+(* Regexp group number as matched by [re_saver_output] for each kind of
+   predicate. *)
+let re_group_number_of_predicate_kind predicate_kind =
+  match predicate_kind with Correct -> 2 | Robust -> 3 | CondRobust -> 4
+
+let build_answer pred_kind prover_result =
+  match prover_result.Call_provers.pr_answer with
+  | Call_provers.HighFailure -> (
+    let pr_output = prover_result.pr_output in
+    match Re__Core.exec_opt re_saver_output pr_output with
+    | Some re_group ->
+      let nb_total =
+        Int.of_string (Re__Core.Group.get re_group re_group_number_dataset_size)
+      in
+      let nb_proved =
+        let re_group_number = re_group_number_of_predicate_kind pred_kind in
+        Int.of_string (Re__Core.Group.get re_group re_group_number)
+      in
+      let prover_answer =
+        if nb_total = nb_proved
+        then Call_provers.Valid
+        else Call_provers.Invalid
+      in
+      { prover_answer; nb_total; nb_proved }
+    | None -> failwith "Cannot interpret the output provided by SAVer")
+  | _ ->
+    (* Any other answer than HighFailure should never happen as we do not define
+       anything in SAVer's driver. *)
+    assert false
+
+let call limit config env config_prover ~dataset task =
+  let interpreted_task = interpret_task env task in
+  let command = build_command config_prover dataset interpreted_task in
+  let prover_call =
+    let res_parser =
+      {
+        Call_provers.prp_regexps =
+          [ ("NeverMatch", Call_provers.Failure "Should not happen in CAISAR") ];
+        prp_timeregexps = [];
+        prp_stepregexps = [];
+        prp_exitcodes = [];
+        prp_model_parser = Model_parser.lookup_model_parser "no_model";
+      }
+    in
+    Call_provers.call_on_buffer ~libdir:(Whyconf.libdir config)
+      ~datadir:(Whyconf.datadir config) ~command ~limit ~res_parser
+      ~filename:" " ~get_counterexmp:false ~gen_new_file:false
+      ~printing_info:Printer.default_printing_info (Buffer.create 10)
+  in
+  let prover_result = Call_provers.wait_on_call prover_call in
+  build_answer interpreted_task.predicate_kind prover_result
diff --git a/src/saver.mli b/src/saver.mli
new file mode 100644
index 00000000..2bec2d23
--- /dev/null
+++ b/src/saver.mli
@@ -0,0 +1,38 @@
+(**************************************************************************)
+(*                                                                        *)
+(*  This file is part of CAISAR.                                          *)
+(*                                                                        *)
+(*  Copyright (C) 2022                                                    *)
+(*    CEA (Commissariat à l'énergie atomique et aux énergies              *)
+(*         alternatives)                                                  *)
+(*                                                                        *)
+(*  You can redistribute it and/or modify it under the terms of the GNU   *)
+(*  Lesser General Public License as published by the Free Software       *)
+(*  Foundation, version 2.1.                                              *)
+(*                                                                        *)
+(*  It is distributed in the hope that it will be useful,                 *)
+(*  but WITHOUT ANY WARRANTY; without even the implied warranty of        *)
+(*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the          *)
+(*  GNU Lesser General Public License for more details.                   *)
+(*                                                                        *)
+(*  See the GNU Lesser General Public License version 2.1                 *)
+(*  for more details (enclosed in the file licenses/LGPLv2.1).            *)
+(*                                                                        *)
+(**************************************************************************)
+
+open Why3
+
+type answer = {
+  prover_answer : Call_provers.prover_answer;
+  nb_total : int; (* Total number of data points. *)
+  nb_proved : int; (* Number of data points verifying the property. *)
+}
+
+val call :
+  Call_provers.resource_limit ->
+  Whyconf.main ->
+  Env.env ->
+  Whyconf.config_prover ->
+  dataset:string ->
+  Task.task ->
+  answer
diff --git a/src/verification.ml b/src/verification.ml
index c455308f..2d12c9dd 100644
--- a/src/verification.ml
+++ b/src/verification.ml
@@ -61,11 +61,6 @@ let create_env ?(debug = false) loadpath =
     config )
 
 let nnet_or_onnx = Re__Core.(compile (str "%{nnet-onnx}"))
-let svm = Re__Core.(compile (str "%{svm}"))
-let dataset = Re__Core.(compile (str "%{dataset}"))
-let epsilon = Re__Core.(compile (str "%{epsilon}"))
-let abstraction = Re__Core.(compile (str "%{abstraction}"))
-let distance = Re__Core.(compile (str "%{distance}"))
 
 let combine_prover_answers answers =
   List.fold_left answers ~init:Call_provers.Valid ~f:(fun acc r ->
@@ -74,86 +69,15 @@ let combine_prover_answers answers =
     | _ -> acc)
 
 let answer_saver limit config env config_prover dataset_csv task =
-  let handle_task_saver task env dataset_csv command =
-    let dataset_filename =
-      match dataset_csv with
-      | None -> invalid_arg "No dataset provided for SAVer"
-      | Some s -> s
-    in
-    let robust_predicate =
-      let dataset_th = Pmodule.read_module env [ "caisar" ] "DataSet" in
-      Theory.ns_find_ls dataset_th.mod_theory.th_export [ "robust" ]
-    in
-    let goal = Task.task_goal_fmla task in
-    let eps, svm_filename =
-      match goal.t_node with
-      | Term.Tapp
-          ( ls,
-            [
-              { t_node = Tapp (ls_svm_apply, _); _ };
-              _dataset;
-              { t_node = Tconst e; _ };
-            ] ) ->
-        if Term.ls_equal ls robust_predicate
-        then
-          let eps = Fmt.str "%a" Constant.print_def e in
-          let svm_filename =
-            match Language.lookup_loaded_svms ls_svm_apply with
-            | Some t -> t.filename
-            | None -> invalid_arg "No SVM model found in task"
-          in
-          (eps, svm_filename)
-        else failwith "Wrong predicate found"
-      | _ ->
-        (* no other predicate than robust_to is supported *)
-        failwith "Unsupported predicate by SAVer"
-    in
-    let svm_file = Unix.realpath svm_filename in
-    let dataset_file = Unix.realpath dataset_filename in
-    let command = Re__Core.replace_string svm ~by:svm_file command in
-    let command = Re__Core.replace_string dataset ~by:dataset_file command in
-    let command = Re__Core.replace_string epsilon ~by:eps command in
-    let command = Re__Core.replace_string distance ~by:"l_inf" command in
-    let command = Re__Core.replace_string abstraction ~by:"hybrid" command in
-    command
-  in
-  let command = Whyconf.get_complete_command ~with_steps:false config_prover in
-  let command = handle_task_saver task env dataset_csv command in
-  let res_parser =
-    {
-      Call_provers.prp_regexps =
-        [ ("NeverMatch", Call_provers.Failure "Should not happen in CAISAR") ];
-      prp_timeregexps = [];
-      prp_stepregexps = [];
-      prp_exitcodes = [];
-      prp_model_parser = Model_parser.lookup_model_parser "no_model";
-    }
+  let dataset =
+    match dataset_csv with
+    | None -> invalid_arg "No dataset provided for SAVer"
+    | Some filename -> filename
   in
-  let prover_call =
-    Call_provers.call_on_buffer ~libdir:(Whyconf.libdir config)
-      ~datadir:(Whyconf.datadir config) ~command ~limit ~res_parser
-      ~filename:" " ~get_counterexmp:false ~gen_new_file:false
-      ~printing_info:Printer.default_printing_info (Buffer.create 10)
-  in
-  let prover_result = Call_provers.wait_on_call prover_call in
-  match prover_result.pr_answer with
-  | Call_provers.HighFailure -> (
-    let pr_output = prover_result.pr_output in
-    let matcher =
-      Re__Pcre.regexp
-        "\\[SUMMARY\\]\\s*(\\d+)\\s*[0-9.]+\\s*[0-9.]+\\s*\\d+\\s*(\\d+)\\s*\\d"
-    in
-    match Re__Core.exec_opt matcher pr_output with
-    | Some g ->
-      if Int.of_string (Re__Core.Group.get g 1)
-         = Int.of_string (Re__Core.Group.get g 2)
-      then Call_provers.Valid
-      else Call_provers.Invalid
-    | None -> Call_provers.HighFailure)
-  | _ ->
-    (* Any other answer than HighFailure should never happen as we do not define
-       anything in SAVer's driver. *)
-    assert false
+  let answer = Saver.call limit config env config_prover ~dataset task in
+  let prover_answer = answer.prover_answer in
+  let additional_info = Fmt.str "(%d/%d)" answer.nb_proved answer.nb_total in
+  (prover_answer, Some additional_info)
 
 let answer_generic limit config prover config_prover driver task =
   let task_prepared = Driver.prepare_task driver task in
@@ -182,11 +106,11 @@ let answer_generic limit config prover config_prover driver task =
     prover_result.pr_answer
   in
   let answers = List.map tasks ~f:call_prover_on_task in
-  let answer = combine_prover_answers answers in
-  answer
+  let prover_answer = combine_prover_answers answers in
+  (prover_answer, None)
 
 let call_prover ~limit config env prover config_prover driver dataset_csv task =
-  let prover_answer =
+  let prover_answer, additional_info =
     match prover with
     | Prover.Saver ->
       answer_saver limit config env config_prover dataset_csv task
@@ -194,8 +118,10 @@ let call_prover ~limit config env prover config_prover driver dataset_csv task =
       answer_generic limit config prover config_prover driver task
   in
   Logs.app (fun m ->
-    m "@[Goal %a:@ %a@]" Pretty.print_pr (Task.task_goal task)
-      Call_provers.print_prover_answer prover_answer)
+    m "@[Goal %a:@ %a%a@]" Pretty.print_pr (Task.task_goal task)
+      Call_provers.print_prover_answer prover_answer
+      Fmt.(option ~none:nop (any " " ++ string))
+      additional_info)
 
 let verify ?(debug = false) format loadpath ?memlimit ?timeout prover
   ?dataset_csv file =
diff --git a/stdlib/caisar.mlw b/stdlib/caisar.mlw
index 675b1ab9..7d02870e 100644
--- a/stdlib/caisar.mlw
+++ b/stdlib/caisar.mlw
@@ -57,15 +57,33 @@ theory DataSet
 
   constant dataset: dataset
 
-  predicate linfty_distance (n: int) (a: features) (b: features) (eps: t) =
-    forall i: int. 0 <= i < n -> .- eps .< a[i] .- b[i] .< eps
+  predicate linfty_distance (a: features) (b: features) (eps: t) =
+    a.length = b.length /\
+    forall i: int. 0 <= i < a.length -> .- eps .< a[i] .- b[i] .< eps
+
+  predicate correct_at (m: model) (d: datum) (eps: t) =
+    forall x': features.
+      let (x, y) = d in
+      linfty_distance x x' eps ->
+      y = predict m x'
+
+  predicate robust_at (m: model) (d: datum) (eps: t) =
+    forall x': features.
+      let (x, _) = d in
+      linfty_distance x x' eps ->
+      predict m x = predict m x'
+
+  predicate cond_robust_at (m: model) (d: datum) (eps: t) =
+    correct_at m d eps /\ robust_at m d eps
+
+  predicate correct (m: model) (d: dataset) (eps: t) =
+    forall i: int. 0 <= i < d.data.length -> correct_at m d.data[i] eps
 
   predicate robust (m: model) (d: dataset) (eps: t) =
-    forall i: int. 0 <= i < d.data.length ->
-      forall x': features.
-        let (x, _) = d.data[i] in
-        linfty_distance d.nb_features x x' eps ->
-        predict m x = predict m x'
+    forall i: int. 0 <= i < d.data.length -> robust_at m d.data[i] eps
+
+  predicate cond_robust (m: model) (d: dataset) (eps: t) =
+    correct m d eps /\ robust m d eps
 end
 
 theory SVM
diff --git a/tests/bin/saver b/tests/bin/saver
index afa49d80..012a62a7 100644
--- a/tests/bin/saver
+++ b/tests/bin/saver
@@ -9,5 +9,5 @@ case $1 in
         echo "SVM: $1"
         echo "Goal:"
         cat $2
-        echo "Unknown"
+        echo "[SUMMARY] 2 8 0.017000 1 2 0"
 esac
diff --git a/tests/simple_ovo.t b/tests/simple_ovo.t
index 1e02da73..cd851827 100644
--- a/tests/simple_ovo.t
+++ b/tests/simple_ovo.t
@@ -29,9 +29,15 @@ Test verify
   > theory T
   >   use TestSVM.SVMasArray
   >   use ieee_float.Float64
+  >   use int.Int
+  >   use array.Array
   >   use caisar.DataSet
   > 
   >   goal G: robust svm_apply dataset (8.0:t)
+  >   goal H: correct svm_apply dataset (8.0:t)
+  >   goal I: cond_robust svm_apply dataset (8.0:t)
   > end
   > EOF
-  [caisar] Goal G: High failure
+  [caisar] Goal G: Valid (2/2)
+  [caisar] Goal H: Invalid (1/2)
+  [caisar] Goal I: Invalid (0/2)
-- 
GitLab