@@ -313,7 +313,6 @@ def optimize(
313313 seed : int = None ,
314314 inits : Union [Dict , float , str ] = None ,
315315 output_dir : str = None ,
316- save_diagnostics : bool = True ,
317316 algorithm : str = None ,
318317 init_alpha : float = None ,
319318 iter : int = None ,
@@ -367,11 +366,6 @@ def optimize(
367366 files are written. If unspecified, output files will be written
368367 to a temporary directory which is deleted upon session exit.
369368
370- :param save_diagnostics: Whether or not to save diagnostics. If True,
371- csv output files are written to an output file with filename
372- template '<model_name>-<YYYYMMDDHHMM>-diagnostic-<chain_id>',
373- e.g. 'bernoulli-201912081451-diagnostic-1.csv'.
374-
375369 :param algorithm: Algorithm to use. One of: 'BFGS', 'LBFGS', 'Newton'
376370
377371 :param init_alpha: Line search step size for first iteration
@@ -393,7 +387,7 @@ def optimize(
393387 seed = seed ,
394388 inits = _inits ,
395389 output_dir = output_dir ,
396- save_diagnostics = save_diagnostics ,
390+ save_diagnostics = False ,
397391 method_args = optimize_args ,
398392 )
399393
@@ -402,12 +396,8 @@ def optimize(
402396 self ._run_cmdstan (runset , dummy_chain_id )
403397
404398 if not runset ._check_retcodes ():
405- msg = 'Error during optimizing'
406- if runset ._retcode (dummy_chain_id ) != 0 :
407- msg = '{}, error code {}' .format (
408- msg , runset ._retcode (dummy_chain_id )
409- )
410- raise RuntimeError (msg )
399+ msg = 'Error during optimization.\n {}' .format (runset .get_err_msgs ())
400+ raise RuntimeError (msg )
411401 mle = CmdStanMLE (runset )
412402 return mle
413403
@@ -750,17 +740,9 @@ def sample(
750740 # re-enable logger for console
751741 self ._logger .propagate = True
752742
753- err_msg = 'Error during sampling.\n '
754743 if not runset ._check_retcodes ():
755- for i in range (chains ):
756- if runset ._retcode (i ) != 0 :
757- err_msg = '{}chain {} returned error code {}\n ' .format (
758- err_msg , i + 1 , runset ._retcode (i )
759- )
760- console_errs = runset ._get_err_msgs ()
761- if len (console_errs ) > 0 :
762- err_msg = '{}{}' .format (err_msg , '' .join (console_errs ))
763- raise RuntimeError (err_msg )
744+ msg = 'Error during sampling.\n {}' .format (runset .get_err_msgs ())
745+ raise RuntimeError (msg )
764746
765747 mcmc = CmdStanMCMC (runset , validate_csv , logger = self ._logger )
766748 return mcmc
@@ -899,12 +881,9 @@ def generate_quantities(
899881 executor .submit (self ._run_cmdstan , runset , i )
900882
901883 if not runset ._check_retcodes ():
902- msg = 'Error during generate_quantities'
903- for i in range (chains ):
904- if runset ._retcode (i ) != 0 :
905- msg = '{}, chain {} returned error code {}' .format (
906- msg , i , runset ._retcode (i )
907- )
884+ msg = 'Error during generate_quantities.\n {}' .format (
885+ runset .get_err_msgs ()
886+ )
908887 raise RuntimeError (msg )
909888 quantities = CmdStanGQ (runset = runset , mcmc_sample = sample_drawset )
910889 return quantities
@@ -1034,12 +1013,10 @@ def variational(
10341013 if not valid :
10351014 raise RuntimeError ('The algorithm may not have converged.' )
10361015 if not runset ._check_retcodes ():
1037- msg = 'Error during variational inference'
1038- if runset ._retcode (dummy_chain_id ) != 0 :
1039- msg = '{}, error code {}' .format (
1040- msg , runset ._retcode (dummy_chain_id )
1041- )
1042- raise RuntimeError (msg )
1016+ msg = 'Error during variational inference.\n {}' .format (
1017+ runset .get_err_msgs ()
1018+ )
1019+ raise RuntimeError (msg )
10431020 # pylint: disable=invalid-name
10441021 vb = CmdStanVB (runset )
10451022 return vb
0 commit comments