Skip to content
Snippets Groups Projects
Commit e4fd53c3 authored by Michele Alberti's avatar Michele Alberti
Browse files

[language] Rework neural networks interface.

parent b6ed30cd
No related branches found
No related tags found
No related merge requests found
......@@ -253,8 +253,8 @@ let caisar_builtins : caisar_env CRE.built_in_theories list =
let filename = Caml.Filename.concat cwd neural_network in
let nn =
match id_string with
| "NNet" -> NNet (Language.create_nnet_nn env filename)
| "ONNX" -> ONNX (Language.create_onnx_nn env filename)
| "NNet" -> NNet (Language.create_nn_nnet env filename)
| "ONNX" -> ONNX (Language.create_nn_onnx env filename)
| _ ->
failwith (Fmt.str "Unrecognized neural network format %s" id_string)
in
......
......@@ -186,8 +186,8 @@ let mem_vector = Term.Hls.mem vectors
(* -- Classifier *)
type nn = {
nn_inputs : int;
nn_outputs : int;
nn_nb_inputs : int;
nn_nb_outputs : int;
nn_ty_elt : Ty.ty; [@printer fun fmt ty -> Fmt.pf fmt "%a" Pretty.print_ty ty]
nn_filename : string;
nn_nier : Onnx.G.t option; [@opaque]
......@@ -204,7 +204,7 @@ let fresh_nn_ls env name =
let id = Ident.id_fresh name in
Term.create_fsymbol id [] ty
let create_nnet_nn =
let create_nn_nnet =
Env.Wenv.memoize 13 (fun env ->
let h = Hashtbl.create (module String) in
let ty_elt =
......@@ -219,8 +219,8 @@ let create_nnet_nn =
| Error s -> Loc.errorm "%s" s
| Ok { n_inputs; n_outputs; _ } ->
{
nn_inputs = n_inputs;
nn_outputs = n_outputs;
nn_nb_inputs = n_inputs;
nn_nb_outputs = n_outputs;
nn_ty_elt = ty_elt;
nn_filename = filename;
nn_nier = None;
......@@ -229,7 +229,7 @@ let create_nnet_nn =
Term.Hls.add nets ls nn;
ls))
let create_onnx_nn =
let create_nn_onnx =
Env.Wenv.memoize 13 (fun env ->
let h = Hashtbl.create (module String) in
let ty_elt = vector_elt_ty env in
......@@ -249,8 +249,8 @@ let create_onnx_nn =
| Ok nier -> Some nier
in
{
nn_inputs = n_inputs;
nn_outputs = n_outputs;
nn_nb_inputs = n_inputs;
nn_nb_outputs = n_outputs;
nn_ty_elt = ty_elt;
nn_filename = filename;
nn_nier = nier;
......
......@@ -72,15 +72,15 @@ val mem_vector : Term.lsymbol -> bool
(** -- Neural Network *)
type nn = private {
nn_inputs : int;
nn_outputs : int;
nn_nb_inputs : int;
nn_nb_outputs : int;
nn_ty_elt : Ty.ty;
nn_filename : string;
nn_nier : Onnx.G.t option;
}
[@@deriving show]
val create_nnet_nn : Env.env -> string -> Term.lsymbol
val create_onnx_nn : Env.env -> string -> Term.lsymbol
val create_nn_nnet : Env.env -> string -> Term.lsymbol
val create_nn_onnx : Env.env -> string -> Term.lsymbol
val lookup_nn : Term.lsymbol -> nn option
val mem_nn : Term.lsymbol -> bool
......@@ -37,8 +37,8 @@ let get_input_variables =
[ { t_node = Tapp (ls1, _); _ }; { t_node = Tapp (ls2, args); _ } ] )
when String.equal ls_name.id_string (Ident.op_infix "@@") -> (
match (Language.lookup_nn ls1, Language.lookup_vector ls2) with
| Some { nn_inputs; _ }, Some n ->
assert (nn_inputs = n && n = List.length args);
| Some { nn_nb_inputs; _ }, Some n ->
assert (nn_nb_inputs = n && n = List.length args);
List.foldi ~init:acc ~f:add args
| _ -> acc)
| _ -> Term.t_fold aux acc term
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment