Skip to content
Snippets Groups Projects
Commit b0002177 authored by Julien Girard-Satabin's avatar Julien Girard-Satabin Committed by Michele Alberti
Browse files

Basic structure for ONNX output through command line.

parent 8c2c18fa
No related branches found
No related tags found
No related merge requests found
...@@ -504,6 +504,46 @@ let nier_of_onnx_protoc (model : Oprotom.t) = ...@@ -504,6 +504,46 @@ let nier_of_onnx_protoc (model : Oprotom.t) =
| Some g -> produce_cfg g | Some g -> produce_cfg g
| None -> raise (ParseError "No graph in ONNX input file found") | None -> raise (ParseError "No graph in ONNX input file found")
let nier_to_onnx_protoc nier =
(* TODO: write a simple ONNX model from a dummy NIER *)
let vertices = G.vertex_list nier in
let protocs =
let vertex_to_protoc v =
let op_type = NCFG.Node.str_op (NCFG.Node.get_op v) in
let name = NCFG.Node.get_name v in
let domain = "" in
let input, output =
(NCFG.Node.get_pred_list v, NCFG.Node.get_succ_list v)
in
Oproto.Onnx.NodeProto.make ~input ~output ~name ~op_type ~domain
~attribute:[] ~doc_string:"" ()
in
List.fold ~init:[] ~f:(fun acc v -> vertex_to_protoc v :: acc) vertices
in
let protog =
Oproto.Onnx.GraphProto.make ~name:"" ~node:protocs ~initializer':[]
~sparse_initializer:[] ~doc_string:"" ~input:[] ~output:[] ~value_info:[]
~quantization_annotation:[] ()
in
let protom =
Oproto.Onnx.ModelProto.make ~ir_version:13 ~opset_import:[]
~producer_name:"CAISAR" ~producer_version:"1.0" ~domain:""
~model_version:(-1) ~doc_string:"" ~graph:protog ~metadata_props:[]
~training_info:[] ~functions:[] ()
in
let writer = Oprotom.to_proto protom in
Ocaml_protoc_plugin.Writer.contents writer
let write_nier_to_onnx _nier out_channel =
let nier = G.init_cfg in
let n =
Ir.Nier_cfg.Node.create ~id:0 ~name:None ~sh:[||] ~op:Ir.Nier_cfg.Node.NO_OP
~op_p:None ~pred:[] ~succ:[] ~tensor:None
in
G.add_vertex nier n;
let onnx = nier_to_onnx_protoc nier in
Stdio.Out_channel.output_string out_channel onnx
let parse_in_channel in_channel = let parse_in_channel in_channel =
let open Result in let open Result in
try try
...@@ -528,3 +568,8 @@ let parse filename = ...@@ -528,3 +568,8 @@ let parse filename =
Fun.protect Fun.protect
~finally:(fun () -> Stdlib.close_in in_channel) ~finally:(fun () -> Stdlib.close_in in_channel)
(fun () -> parse_in_channel in_channel) (fun () -> parse_in_channel in_channel)
let write nier filename =
let out_chan = Stdlib.open_out filename in
write_nier_to_onnx nier out_chan;
Stdlib.close_out out_chan
...@@ -30,4 +30,7 @@ type t = private { ...@@ -30,4 +30,7 @@ type t = private {
(** ONNX model metadata and intermediate representation. *) (** ONNX model metadata and intermediate representation. *)
val parse : string -> (t, string) Result.t val parse : string -> (t, string) Result.t
(** Parse an ONNX file. *) (** Parse an ONNX file into a NIER. *)
val write : G.t -> string -> unit
(** Write a NIER into an ONNX file. *)
...@@ -244,7 +244,8 @@ let create_nn_onnx = ...@@ -244,7 +244,8 @@ let create_nn_onnx =
Logs.warn (fun m -> Logs.warn (fun m ->
m "Cannot build network intermediate representation:@ %s" msg); m "Cannot build network intermediate representation:@ %s" msg);
None None
| Ok nier -> Some nier | Ok nier ->
Some nier
in in
{ {
nn_nb_inputs = n_inputs; nn_nb_inputs = n_inputs;
......
...@@ -123,14 +123,14 @@ let log_theory_answer = ...@@ -123,14 +123,14 @@ let log_theory_answer =
additional_info))) additional_info)))
let verify ?format ~loadpath ?memlimit ?timelimit ?dataset prover ?prover_altern let verify ?format ~loadpath ?memlimit ?timelimit ?dataset prover ?prover_altern
?def_constants ?theories ?goals files = ?def_constants ?theories ?goals ?onnx_out_file files =
let memlimit = Option.map memlimit ~f:memlimit_of_string in let memlimit = Option.map memlimit ~f:memlimit_of_string in
let timelimit = Option.map timelimit ~f:timelimit_of_string in let timelimit = Option.map timelimit ~f:timelimit_of_string in
let theory_answers = let theory_answers =
List.map files List.map files
~f: ~f:
(Verification.verify ?format ~loadpath ?memlimit ?timelimit ?dataset (Verification.verify ?format ~loadpath ?memlimit ?timelimit ?dataset
prover ?prover_altern ?def_constants ?theories ?goals) prover ?prover_altern ?def_constants ?theories ?goals ?onnx_out_file)
in in
List.iter theory_answers ~f:log_theory_answer; List.iter theory_answers ~f:log_theory_answer;
theory_answers theory_answers
...@@ -256,6 +256,10 @@ let verify_cmd = ...@@ -256,6 +256,10 @@ let verify_cmd =
let doc = "Dataset $(docv) (CSV format only)." in let doc = "Dataset $(docv) (CSV format only)." in
Arg.(value & opt (some file) None & info [ "dataset" ] ~doc ~docv:"FILE") Arg.(value & opt (some file) None & info [ "dataset" ] ~doc ~docv:"FILE")
in in
let onnx_out_file =
let doc = "Path where to save the ONNX outputs from NIER." in
Arg.(value & opt (some string) None & info [ "onnx-out-file" ] ~doc)
in
let define_constants = let define_constants =
let doc = "Define a declared constant with the given value." in let doc = "Define a declared constant with the given value." in
Arg.( Arg.(
...@@ -295,16 +299,17 @@ let verify_cmd = ...@@ -295,16 +299,17 @@ let verify_cmd =
in in
let verify_term = let verify_term =
let verify format loadpath memlimit timelimit prover prover_altern dataset let verify format loadpath memlimit timelimit prover prover_altern dataset
def_constants theories goals files () = def_constants theories goals onnx_out_file files () =
ignore ignore
(verify ?format ~loadpath ?memlimit ?timelimit ?dataset prover (verify ?format ~loadpath ?memlimit ?timelimit ?dataset prover
?prover_altern ~def_constants ~theories ~goals files) ?prover_altern ~def_constants ~theories ~goals ?onnx_out_file files)
in in
Term.( Term.(
const (fun _ -> exec_cmd cmdname) const (fun _ -> exec_cmd cmdname)
$ setup_logs $ setup_logs
$ (const verify $ format $ loadpath $ memlimit $ timelimit $ prover $ (const verify $ format $ loadpath $ memlimit $ timelimit $ prover
$ prover_altern $ dataset $ define_constants $ theories $ goals $ files)) $ prover_altern $ dataset $ define_constants $ theories $ goals
$ onnx_out_file $ files))
in in
Cmd.v info verify_term Cmd.v info verify_term
......
...@@ -243,8 +243,8 @@ let answer_dataset limit config env prover config_prover driver dataset task = ...@@ -243,8 +243,8 @@ let answer_dataset limit config env prover config_prover driver dataset task =
in in
(prover_answer, additional_info) (prover_answer, additional_info)
let answer_generic limit config prover config_prover driver ~proof_strategy task let answer_generic limit config prover config_prover driver ~proof_strategy
= onnx_out_file task =
let tasks = proof_strategy task in let tasks = proof_strategy task in
let answers = let answers =
List.concat_map tasks ~f:(fun task -> List.concat_map tasks ~f:(fun task ->
...@@ -255,6 +255,33 @@ let answer_generic limit config prover config_prover driver ~proof_strategy task ...@@ -255,6 +255,33 @@ let answer_generic limit config prover config_prover driver ~proof_strategy task
| Some _ -> assert false (* By construction of the meta. *) | Some _ -> assert false (* By construction of the meta. *)
| None -> invalid_arg "No neural network model found in task" | None -> invalid_arg "No neural network model found in task"
in in
let () =
let save_onnx_in ls f =
match Language.lookup_nn ls with
| Some { nn_nier = Some g; _ } -> (
try
Onnx.write g f;
Logs.info (fun m -> m "@[Wrote ONNX file at '%s'@]" f)
with Sys_error msg ->
Logs.err (fun m ->
m "@[System error: tried to write ONNX file a '%s', got '%s'@]"
f msg))
| None -> ()
| _ -> ()
in
match onnx_out_file with
| Some f ->
Task.task_iter
(fun decl ->
match decl.td_node with
| Use _ | Clone _ | Meta _ -> ()
| Decl decl -> (
match decl.d_node with
| Dparam ls -> save_onnx_in ls f
| _ -> ()))
task
| _ -> ()
in
let tasks = let tasks =
(* Turn [task] into a list (ie, conjunction) of disjunctions of (* Turn [task] into a list (ie, conjunction) of disjunctions of
tasks. *) tasks. *)
...@@ -271,7 +298,7 @@ let answer_generic limit config prover config_prover driver ~proof_strategy task ...@@ -271,7 +298,7 @@ let answer_generic limit config prover config_prover driver ~proof_strategy task
(prover_answer, additional_info) (prover_answer, additional_info)
let call_prover ~cwd ~limit config env prover config_prover driver ?dataset let call_prover ~cwd ~limit config env prover config_prover driver ?dataset
def_constants task = def_constants onnx_out_file task =
let prover_answer, additional_info = let prover_answer, additional_info =
match prover with match prover with
| Prover.Saver -> answer_saver limit config env config_prover dataset task | Prover.Saver -> answer_saver limit config env config_prover dataset task
...@@ -283,12 +310,13 @@ let call_prover ~cwd ~limit config env prover config_prover driver ?dataset ...@@ -283,12 +310,13 @@ let call_prover ~cwd ~limit config env prover config_prover driver ?dataset
let task = Interpretation.interpret_task ~cwd env ~def_constants task in let task = Interpretation.interpret_task ~cwd env ~def_constants task in
let proof_strategy = Proof_strategy.apply_native_nn_prover in let proof_strategy = Proof_strategy.apply_native_nn_prover in
answer_generic limit config prover config_prover driver ~proof_strategy answer_generic limit config prover config_prover driver ~proof_strategy
task onnx_out_file task
| CVC5 -> | CVC5 ->
let task = Interpretation.interpret_task ~cwd env ~def_constants task in let task = Interpretation.interpret_task ~cwd env ~def_constants task in
let proof_strategy = Proof_strategy.apply_classic_prover env in let proof_strategy = Proof_strategy.apply_classic_prover env in
(* Not outputting ONNX for CVC5 as it does not handle such format*)
answer_generic limit config prover config_prover driver ~proof_strategy answer_generic limit config prover config_prover driver ~proof_strategy
task None task
in in
let id = Task.task_goal task in let id = Task.task_goal task in
{ id; prover_answer; additional_info } { id; prover_answer; additional_info }
...@@ -328,7 +356,7 @@ let tasks_of_theory ~theories ~goals theory = ...@@ -328,7 +356,7 @@ let tasks_of_theory ~theories ~goals theory =
List.exists goals_of_theory ~f:(String.equal task_goal_id)) List.exists goals_of_theory ~f:(String.equal task_goal_id))
let verify ?format ~loadpath ?memlimit ?timelimit ?dataset prover ?prover_altern let verify ?format ~loadpath ?memlimit ?timelimit ?dataset prover ?prover_altern
?(def_constants = []) ?(theories = []) ?(goals = []) file = ?(def_constants = []) ?(theories = []) ?(goals = []) ?onnx_out_file file =
let debug = Logging.(is_debug_level src_prover_call) in let debug = Logging.(is_debug_level src_prover_call) in
(if debug then Debug.(set_flag (lookup_flag "call_prover"))); (if debug then Debug.(set_flag (lookup_flag "call_prover")));
let env, config = create_env loadpath in let env, config = create_env loadpath in
...@@ -391,6 +419,6 @@ let verify ?format ~loadpath ?memlimit ?timelimit ?dataset prover ?prover_altern ...@@ -391,6 +419,6 @@ let verify ?format ~loadpath ?memlimit ?timelimit ?dataset prover ?prover_altern
List.map List.map
~f: ~f:
(call_prover ~cwd ~limit main env prover config_prover driver ?dataset (call_prover ~cwd ~limit main env prover config_prover driver ?dataset
def_constants) def_constants onnx_out_file)
tasks) tasks)
mstr_theory mstr_theory
...@@ -55,6 +55,7 @@ val verify : ...@@ -55,6 +55,7 @@ val verify :
?def_constants:(string * string) list -> ?def_constants:(string * string) list ->
?theories:string list -> ?theories:string list ->
?goals:(string * string list) list -> ?goals:(string * string list) list ->
?onnx_out_file: string ->
File.t -> File.t ->
verification_result verification_result
(** Starts a verification of the given [file] with the provided [prover]. (** Starts a verification of the given [file] with the provided [prover].
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment