Skip to content
Snippets Groups Projects
Commit 687dd42c authored by Michele Alberti's avatar Michele Alberti
Browse files

[interpretation] Remove notion of tensor from code and library.

parent 2269f8ea
No related branches found
No related tags found
No related merge requests found
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
(* *) (* *)
(**************************************************************************) (**************************************************************************)
module CRE = Reduction_engine (* Caisar Reduction Engine *) module CRE = Reduction_engine (* CAISAR Reduction Engine *)
open Why3 open Why3
open Base open Base
...@@ -43,7 +43,6 @@ type dataset = DS_csv of Csv.t [@printer fun fmt _ -> Fmt.pf fmt "<csv>"] ...@@ -43,7 +43,6 @@ type dataset = DS_csv of Csv.t [@printer fun fmt _ -> Fmt.pf fmt "<csv>"]
[@@deriving show] [@@deriving show]
type data = D_csv of string list [@@deriving show] type data = D_csv of string list [@@deriving show]
type index = I_csv of int [@@deriving show]
type vector = type vector =
(Language.vector (Language.vector
...@@ -56,9 +55,7 @@ type caisar_op = ...@@ -56,9 +55,7 @@ type caisar_op =
| Classifier of classifier | Classifier of classifier
| Dataset of dataset | Dataset of dataset
| Data of data | Data of data
| Index of index
| Vector of vector | Vector of vector
| Tensor of int
[@@deriving show] [@@deriving show]
type caisar_env = { type caisar_env = {
...@@ -163,7 +160,7 @@ let builtin_caisar : caisar_env CRE.built_in_theories list = ...@@ -163,7 +160,7 @@ let builtin_caisar : caisar_env CRE.built_in_theories list =
let n = Option.value_exn (Language.lookup_vector v) in let n = Option.value_exn (Language.lookup_vector v) in
assert (List.length tl1 = n && i <= n); assert (List.length tl1 = n && i <= n);
term (List.nth_exn tl1 i) term (List.nth_exn tl1 i)
| Data _ | Classifier _ | Tensor _ | Index _ -> assert false) | Data _ | Classifier _ -> assert false)
| [ Term t1; Term t2 ] -> | [ Term t1; Term t2 ] ->
(* Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2; *) (* Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2; *)
term (Term.t_app_infer ls [ t1; t2 ]) term (Term.t_app_infer ls [ t1; t2 ])
...@@ -181,14 +178,14 @@ let builtin_caisar : caisar_env CRE.built_in_theories list = ...@@ -181,14 +178,14 @@ let builtin_caisar : caisar_env CRE.built_in_theories list =
| Vector v -> | Vector v ->
int (BigInt.of_int (Option.value_exn (Language.lookup_vector v))) int (BigInt.of_int (Option.value_exn (Language.lookup_vector v)))
| Data (D_csv data) -> int (BigInt.of_int (List.length data)) | Data (D_csv data) -> int (BigInt.of_int (List.length data))
| Classifier _ | Tensor _ | Index _ -> assert false) | Classifier _ -> assert false)
| [ Term { t_node = Tapp (ls, tl); _ } ] -> ( | [ Term { t_node = Tapp (ls, tl); _ } ] -> (
match caisar_op_of_ls engine ls with match caisar_op_of_ls engine ls with
| Vector v -> | Vector v ->
let n = Option.value_exn (Language.lookup_vector v) in let n = Option.value_exn (Language.lookup_vector v) in
assert (List.length tl = n); assert (List.length tl = n);
int (BigInt.of_int n) int (BigInt.of_int n)
| Dataset _ | Data _ | Classifier _ | Tensor _ | Index _ -> assert false) | Dataset _ | Data _ | Classifier _ -> assert false)
| [ Term t ] -> | [ Term t ] ->
(* Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2; *) (* Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2; *)
term (Term.t_app_infer ls [ t ]) term (Term.t_app_infer ls [ t ])
...@@ -264,71 +261,7 @@ let builtin_caisar : caisar_env CRE.built_in_theories list = ...@@ -264,71 +261,7 @@ let builtin_caisar : caisar_env CRE.built_in_theories list =
in in
Eval (term_of_caisar_op ~args engine caisar_op ty) Eval (term_of_caisar_op ~args engine caisar_op ty)
| Dataset (DS_csv csv) -> int (BigInt.of_int (Csv.lines csv)) | Dataset (DS_csv csv) -> int (BigInt.of_int (Csv.lines csv))
| Data _ | Classifier _ | Tensor _ | Index _ -> assert false) | Data _ | Classifier _ -> assert false)
| [ Term t1; Term t2 ] ->
(* Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2; *)
term (Term.t_app_infer ls [ t1; t2 ])
| _ -> invalid_arg (error_message ls)
in
(* Tensor *)
let tget : _ CRE.builtin =
fun engine ls vl _ty ->
(* Fmt.pr "--@.tget: ls:%a , ty:%a@." Pretty.print_ls ls *)
(* Fmt.(option ~none:nop Pretty.print_ty) *)
(* ty; *)
match vl with
| [
Term ({ t_node = Tapp (ls1, tl1); _ } as _t1);
Term ({ t_node = Tapp (ls2, _); _ } as _t2);
] -> (
(* Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2; *)
match (caisar_op_of_ls engine ls1, caisar_op_of_ls engine ls2) with
| Tensor n, Index (I_csv i) ->
assert (i <= n);
term (List.nth_exn tl1 i)
| _ -> assert false)
| [ Term t1; Term t2 ] ->
(* Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2; *)
term (Term.t_app_infer ls [ t1; t2 ])
| _ -> invalid_arg (error_message ls)
in
let tminus : _ CRE.builtin =
fun engine ls vl ty ->
(* Fmt.pr "--@.tminus: ls:%a , ty:%a@." Pretty.print_ls ls *)
(* Fmt.(option ~none:nop Pretty.print_ty) *)
(* ty; *)
match vl with
| [
Term ({ t_node = Tapp (ls1, tl1); _ } as _t1);
Term ({ t_node = Tapp (ls2, _); _ } as _t2);
] -> (
(* Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2; *)
match (caisar_op_of_ls engine ls1, caisar_op_of_ls engine ls2) with
| Tensor n, Data (D_csv data) ->
assert (n = List.length data);
let ty_cst =
match ty with
| Some { ty_node = Tyapp (_, [ ty ]); _ } -> ty
| _ -> assert false
in
let csts =
List.map data ~f:(fun d ->
let cst = const_real_of_float (Float.of_string d) in
Term.t_const cst ty_cst)
in
let minus =
(* TODO: generalize wrt the type of constants [csts]. *)
let { env; _ } = CRE.user_env engine in
let th = Env.read_theory env [ "ieee_float" ] "Float64" in
Theory.(ns_find_ls th.th_export [ Ident.op_infix ".-" ])
in
let args =
List.map2_exn tl1 csts ~f:(fun tl c ->
(Term.t_app_infer minus [ tl; c ], ty_cst))
in
term (term_of_caisar_op ~args engine (Tensor n) ty)
| _ -> assert false)
| [ Term t1; Term t2 ] -> | [ Term t1; Term t2 ] ->
(* Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2; *) (* Fmt.pr "Terms: %a , %a@." Pretty.print_term t1 Pretty.print_term t2; *)
term (Term.t_app_infer ls [ t1; t2 ]) term (Term.t_app_infer ls [ t1; t2 ])
...@@ -402,19 +335,11 @@ let builtin_caisar : caisar_env CRE.built_in_theories list = ...@@ -402,19 +335,11 @@ let builtin_caisar : caisar_env CRE.built_in_theories list =
([ "length" ], None, length); ([ "length" ], None, length);
([ "L"; "mapi" ], None, mapi); ([ "L"; "mapi" ], None, mapi);
] ); ] );
( [ "interpretation" ],
"Tensor",
[],
[
([ Ident.op_infix "#" ], None, tget);
([ Ident.op_infix "-" ], None, tminus);
] );
( [ "interpretation" ], ( [ "interpretation" ],
"Classifier", "Classifier",
[], [],
[ [
([ "read_classifier" ], None, read_classifier); ([ "read_classifier" ], None, read_classifier);
([ Ident.op_infix "@@" ], None, apply_classifier);
([ Ident.op_infix "%%" ], None, apply_classifier); ([ Ident.op_infix "%%" ], None, apply_classifier);
] ); ] );
( [ "interpretation" ], ( [ "interpretation" ],
...@@ -425,37 +350,6 @@ let builtin_caisar : caisar_env CRE.built_in_theories list = ...@@ -425,37 +350,6 @@ let builtin_caisar : caisar_env CRE.built_in_theories list =
let bounded_quant engine vs ~cond : CRE.bounded_quant_result option = let bounded_quant engine vs ~cond : CRE.bounded_quant_result option =
match cond.Term.t_node with match cond.Term.t_node with
| Tapp
( { ls_name = { id_string = "equal_shape"; _ }; _ },
[
({ t_node = Tapp (ls, _); _ } as _t1);
({ t_node = Tvar vs2; _ } as _t2);
] ) ->
(* Fmt.pr "--@.equal_shape: %a %a@." Pretty.print_term t1 Pretty.print_term
t2; *)
if not (Term.vs_equal vs vs2)
then None
else
let n =
match caisar_op_of_ls engine ls with
| Data (D_csv d) -> List.length d
| _ -> assert false
in
let ty =
match vs.vs_ty with
| { ty_node = Tyapp (_, ty :: _); _ } -> ty
| _ -> assert false
in
let new_quant =
List.init n ~f:(fun _ ->
let preid = Ident.id_fresh "caisar_t" in
Term.create_vsymbol preid ty)
in
let args = List.map new_quant ~f:(fun vs -> (Term.t_var vs, ty)) in
let substitutions =
[ term_of_caisar_op ~args engine (Tensor n) (Some vs.vs_ty) ]
in
Some { new_quant; substitutions }
| Tapp | Tapp
( { ls_name = { id_string = "has_length"; _ }; _ }, ( { ls_name = { id_string = "has_length"; _ }; _ },
[ [
...@@ -487,25 +381,6 @@ let bounded_quant engine vs ~cond : CRE.bounded_quant_result option = ...@@ -487,25 +381,6 @@ let bounded_quant engine vs ~cond : CRE.bounded_quant_result option =
[ term_of_caisar_op ~args engine caisar_op (Some vs.vs_ty) ] [ term_of_caisar_op ~args engine caisar_op (Some vs.vs_ty) ]
in in
Some { new_quant; substitutions } Some { new_quant; substitutions }
| Tapp
( { ls_name = { id_string = "valid_index"; _ }; _ },
[
({ t_node = Tapp (ls, _); _ } as _t1);
({ t_node = Tvar vs2; _ } as _t2);
] ) -> (
if not (Term.vs_equal vs vs2)
then None
else
match caisar_op_of_ls engine ls with
| Tensor n ->
let new_quant = [] in
let substitutions =
List.init n ~f:(fun i ->
term_of_caisar_op engine (Index (I_csv i)) (Some vs.vs_ty))
in
Some { new_quant; substitutions }
| _ -> assert false
| exception _ -> None)
| _ -> None | _ -> None
let interpret_task ~cwd env task = let interpret_task ~cwd env task =
......
...@@ -52,35 +52,18 @@ theory Vector ...@@ -52,35 +52,18 @@ theory Vector
end end
end end
theory Tensor
use int.Int
use Vector
type tensor 'a
type index = vector int
function (#) (t: tensor 'a) (v: vector int) : 'a
function (-) (t1: tensor 'a) (t2: tensor 'a) : tensor 'a
predicate equal_shape (t1: tensor 'a) (t2: tensor 'b)
predicate valid_index (t: tensor 'a) (v: index)
end
theory Classifier theory Classifier
use Vector use Vector
use Tensor
type classifier type classifier
type kind = ONNX | NNet | OVO type kind = ONNX | NNet | OVO
function read_classifier (f: string) (k: kind) : classifier function read_classifier (f: string) (k: kind) : classifier
function (@@) (c: classifier) (t: tensor 'a) : vector 'a
function (%%) (c: classifier) (v: vector 'a) : vector 'a function (%%) (c: classifier) (v: vector 'a) : vector 'a
end end
theory Dataset theory Dataset
use Vector use Vector
use Tensor
type dataset 'a 'b = vector ('a, 'b) type dataset 'a 'b = vector ('a, 'b)
type format = CSV type format = CSV
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment