From ec002e5fe67bbad5c3c60f8b85891213e111382e Mon Sep 17 00:00:00 2001 From: Michele Alberti <michele.alberti@cea.fr> Date: Tue, 11 Apr 2023 13:14:39 +0200 Subject: [PATCH] [interpretation] Use classifier logic symbols created by language module. --- src/interpretation.ml | 5 ++++- tests/interpretation_acasxu.t | 29 +++++++++++++------------ tests/interpretation_dataset.t | 39 ++++++++++++++++++++-------------- 3 files changed, 43 insertions(+), 30 deletions(-) diff --git a/src/interpretation.ml b/src/interpretation.ml index b5d97f8..2fa1b82 100644 --- a/src/interpretation.ml +++ b/src/interpretation.ml @@ -75,7 +75,10 @@ let ls_of_caisar_op engine op ty_args ty = Hashtbl.find_or_add caisar_env.ls_of_caisar_op op ~default:(fun () -> let id = Ident.id_fresh "caisar_op" in let ls = - match op with Vector v -> v | _ -> Term.create_lsymbol id ty_args ty + match op with + | Classifier (NNet c | ONNX c) -> c + | Vector v -> v + | _ -> Term.create_lsymbol id ty_args ty in (* Fmt.pr "ls: %a@." Pretty.print_ls ls; *) Hashtbl.Poly.add_exn caisar_env.ls_of_caisar_op ~key:op ~data:ls; diff --git a/tests/interpretation_acasxu.t b/tests/interpretation_acasxu.t index 9d1adea..93fae2a 100644 --- a/tests/interpretation_acasxu.t +++ b/tests/interpretation_acasxu.t @@ -126,17 +126,19 @@ Test interpret on acasxu le (add RNE (mul RNE - (caisar_op %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) + (nnet_classifier + %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [0] (373.9499200000000200816430151462554931640625:t)) (7.51888402010059753166615337249822914600372314453125:t)) (1500.0:t) - vector, (Interpretation.Vector 5) - caisar_op, + nnet_classifier, (Interpretation.Classifier NNet: { Language.nn_inputs = 5; nn_outputs = 5; nn_ty_elt = t; nn_filename = "$TESTCASE_ROOT/TestNetwork.nnet"; nn_nier = <opaque> }) + vector, + (Interpretation.Vector 5) P2 : forall caisar_v:t, caisar_v1:t, caisar_v2:t, caisar_v3:t, caisar_v4:t. (le (0.0:t) (add RNE (mul RNE caisar_v (60261.0:t)) (19791.0:t)) /\ le (add RNE (mul RNE caisar_v (60261.0:t)) (19791.0:t)) (60760.0:t) -> @@ -187,37 +189,38 @@ Test interpret on acasxu le (900.0:t) (add RNE (mul RNE caisar_v3 (1100.0:t)) (650.0:t)) -> le (960.0:t) (add RNE (mul RNE caisar_v4 (1200.0:t)) (600.0:t))) -> not (((lt - (caisar_op1 + (nnet_classifier %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [0] - (caisar_op1 + (nnet_classifier %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [1] /\ lt - (caisar_op1 + (nnet_classifier %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [0] - (caisar_op1 + (nnet_classifier %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [2]) /\ lt - (caisar_op1 + (nnet_classifier %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [0] - (caisar_op1 + (nnet_classifier %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [3]) /\ lt - (caisar_op1 + (nnet_classifier %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [0] - (caisar_op1 + (nnet_classifier %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [4]) - vector, (Interpretation.Vector 5) - caisar_op1, + nnet_classifier, (Interpretation.Classifier NNet: { Language.nn_inputs = 5; nn_outputs = 5; nn_ty_elt = t; nn_filename = "$TESTCASE_ROOT/TestNetwork.nnet"; nn_nier = <opaque> }) + vector, + (Interpretation.Vector 5) diff --git a/tests/interpretation_dataset.t b/tests/interpretation_dataset.t index 3ff6693..de5f086 100644 --- a/tests/interpretation_dataset.t +++ b/tests/interpretation_dataset.t @@ -78,14 +78,18 @@ Test interpret on dataset (0.776470588000000017103729987866245210170745849609375:t)) (0.375:t) -> lt - (caisar_op %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) + (nnet_classifier + %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [0] - (caisar_op %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) + (nnet_classifier + %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [1] /\ lt - (caisar_op %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) + (nnet_classifier + %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [2] - (caisar_op %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) + (nnet_classifier + %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [1]) /\ (forall caisar_v:t, caisar_v1:t, caisar_v2:t, caisar_v3:t, caisar_v4:t. ((((le (0.0:t) caisar_v /\ le caisar_v (1.0:t)) /\ @@ -119,29 +123,32 @@ Test interpret on dataset (0.78431372499999996161790249971090815961360931396484375:t)) (0.375:t) -> lt - (caisar_op %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) + (nnet_classifier + %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [1] - (caisar_op %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) + (nnet_classifier + %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [0] /\ lt - (caisar_op %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) + (nnet_classifier + %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [2] - (caisar_op %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) + (nnet_classifier + %% vector caisar_v caisar_v1 caisar_v2 caisar_v3 caisar_v4) [0]) - caisar_op1, - (Interpretation.Data - (Interpretation.D_csv - ["0.0"; "1.0"; "0.784313725"; "0.019607843"; "0.776470588"])) - caisar_op2, + caisar_op, (Interpretation.Data (Interpretation.D_csv ["1.0"; "0.0"; "0.019607843"; "0.776470588"; "0.784313725"])) vector, (Interpretation.Vector 5) - caisar_op, + nnet_classifier, (Interpretation.Classifier NNet: { Language.nn_inputs = 5; nn_outputs = 5; nn_ty_elt = t; nn_filename = "$TESTCASE_ROOT/TestNetwork.nnet"; nn_nier = <opaque> }) - caisar_op3, - (Interpretation.Dataset <csv>) + caisar_op1, (Interpretation.Dataset <csv>) + caisar_op2, + (Interpretation.Data + (Interpretation.D_csv + ["0.0"; "1.0"; "0.784313725"; "0.019607843"; "0.776470588"])) -- GitLab