diff --git a/solver.ml b/solver.ml index 0a609ad3996943ce114e3f0d2578f82615b638be..3c0796d3994bfb83298350dc9799a66b1ee758d2 100644 --- a/solver.ml +++ b/solver.ml @@ -50,28 +50,51 @@ let default_exe_name_of_solver = function | Pyrat -> "pyrat.py" | Marabou -> "Marabou" +let default_option_of_solver = function + | Pyrat -> "--version" + | Marabou -> "--version" + let exe_name_of_solver solver = match Sys.getenv_opt (env_var_of_solver solver) with | None -> default_exe_name_of_solver solver | Some v -> v -let check_availability solver = +let default_exec_of_solver solver = let module Filename = Caml.Filename in + let tmp = Filename.temp_file "caisar" "" in + let cmd = + Filename.quote_command + ~stdout:tmp ~stderr:tmp + (exe_name_of_solver solver) + [default_option_of_solver solver] + in + let retcode = Sys.command cmd in + let in_channel = Stdlib.open_in tmp in + let firstline = Stdlib.input_line in_channel in + Stdlib.close_in in_channel; + Sys.remove tmp; + cmd, retcode, firstline + +let version_of_solver solver s = + let regexp_string = + (* We use same pattern to extract solver version numbers for the moment. *) + match solver with + | Pyrat | Marabou -> + (* We want to match for version string of the form '0.0.1alpha+'. *) + "[0-9]\\(\\.[0-9]\\)*\\(\\.?[A-Za-z-+]\\)*" + in + let regexp_version = Str.regexp regexp_string in + try + ignore (Str.search_forward regexp_version s 0); + Ok (Str.matched_string s) + with Stdlib.Not_found -> + Error "No recognizable version found." + +let check_availability ~err_on_version_mismatch solver = let solver_name = show_solver solver in Logs.info (fun m -> m "Checking availability of `%s'." solver_name); try - let tmp = Filename.temp_file "caisar" "" in - let cmd = - Filename.quote_command - ~stdout:tmp ~stderr:tmp - (exe_name_of_solver solver) - ["--version"] - in - let retcode = Sys.command cmd in - let in_channel = Stdlib.open_in tmp in - let firstline = Stdlib.input_line in_channel in - Stdlib.close_in in_channel; - Sys.remove tmp; + let cmd, retcode, firstline = default_exec_of_solver solver in if retcode <> 0 then Error @@ -81,18 +104,18 @@ let check_availability solver = use variable `%s' to directly provide the path to the executable." cmd (env_var_of_solver solver)) else begin - (try - let regexp_version = - (* We want to match for version string of the form '0.0.1alpha+'. *) - Str.regexp "[0-9]\\(\\.[0-9]\\)*\\(\\.?[A-Za-z-+]\\)*" - in - ignore (Str.search_forward regexp_version firstline 0); - Logs.info - (fun m -> m "Found version `%s'." (Str.matched_string firstline)) - with Stdlib.Not_found -> - Logs.warn - (fun m -> m "Found unrecognizable version of `%s'." solver_name)); - Ok () + match version_of_solver solver firstline with + | Error _ as error -> + if err_on_version_mismatch + then error + else begin + Logs.warn + (fun m -> m "Found unrecognizable version of `%s'." solver_name); + Ok () + end + | Ok version -> + Logs.info (fun m -> m "Found version `%s'." version); + Ok () end with Stdlib.Not_found | End_of_file | Sys_error _ -> Error "Unexpected failure." @@ -102,7 +125,7 @@ let check_compatibility solver model = | (Pyrat | Marabou), (Model.Onnx as f) -> Error (Format.sprintf - "Cannot deal with `%s' and model format `%s'." + "No support yet for `%s' and model format `%s'." (show_solver solver) (Model.show_format f)) | _ -> @@ -159,7 +182,7 @@ let build_command ?raw_options solver property model = let exec ?raw_options solver model property = let open Result in (* Check solver availability in PATH. *) - check_availability solver >>= fun () -> + check_availability ~err_on_version_mismatch:false solver >>= fun () -> (* Check solver and model compatibility. *) check_compatibility solver model >>= fun () -> (* Build the required command-line. *)