Skip to content

Commit 254f9de

Browse files
authored
Merge pull request #280 from stan-dev/feature/22-better-error-handling
Feature/22 better error handling
2 parents: d758170 + c6d0c57 · commit 254f9de

5 files changed

Lines changed: 68 additions & 49 deletions

File tree

cmdstanpy/model.py

Lines changed: 12 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -313,7 +313,6 @@ def optimize(
313313
seed: int = None,
314314
inits: Union[Dict, float, str] = None,
315315
output_dir: str = None,
316-
save_diagnostics: bool = True,
317316
algorithm: str = None,
318317
init_alpha: float = None,
319318
iter: int = None,
@@ -367,11 +366,6 @@ def optimize(
367366
files are written. If unspecified, output files will be written
368367
to a temporary directory which is deleted upon session exit.
369368
370-
:param save_diagnostics: Whether or not to save diagnostics. If True,
371-
csv output files are written to an output file with filename
372-
template '<model_name>-<YYYYMMDDHHMM>-diagnostic-<chain_id>',
373-
e.g. 'bernoulli-201912081451-diagnostic-1.csv'.
374-
375369
:param algorithm: Algorithm to use. One of: 'BFGS', 'LBFGS', 'Newton'
376370
377371
:param init_alpha: Line search step size for first iteration
@@ -393,7 +387,7 @@ def optimize(
393387
seed=seed,
394388
inits=_inits,
395389
output_dir=output_dir,
396-
save_diagnostics=save_diagnostics,
390+
save_diagnostics=False,
397391
method_args=optimize_args,
398392
)
399393

@@ -402,12 +396,8 @@ def optimize(
402396
self._run_cmdstan(runset, dummy_chain_id)
403397

404398
if not runset._check_retcodes():
405-
msg = 'Error during optimizing'
406-
if runset._retcode(dummy_chain_id) != 0:
407-
msg = '{}, error code {}'.format(
408-
msg, runset._retcode(dummy_chain_id)
409-
)
410-
raise RuntimeError(msg)
399+
msg = 'Error during optimization.\n{}'.format(runset.get_err_msgs())
400+
raise RuntimeError(msg)
411401
mle = CmdStanMLE(runset)
412402
return mle
413403

@@ -750,17 +740,9 @@ def sample(
750740
# re-enable logger for console
751741
self._logger.propagate = True
752742

753-
err_msg = 'Error during sampling.\n'
754743
if not runset._check_retcodes():
755-
for i in range(chains):
756-
if runset._retcode(i) != 0:
757-
err_msg = '{}chain {} returned error code {}\n'.format(
758-
err_msg, i + 1, runset._retcode(i)
759-
)
760-
console_errs = runset._get_err_msgs()
761-
if len(console_errs) > 0:
762-
err_msg = '{}{}'.format(err_msg, ''.join(console_errs))
763-
raise RuntimeError(err_msg)
744+
msg = 'Error during sampling.\n{}'.format(runset.get_err_msgs())
745+
raise RuntimeError(msg)
764746

765747
mcmc = CmdStanMCMC(runset, validate_csv, logger=self._logger)
766748
return mcmc
@@ -899,12 +881,9 @@ def generate_quantities(
899881
executor.submit(self._run_cmdstan, runset, i)
900882

901883
if not runset._check_retcodes():
902-
msg = 'Error during generate_quantities'
903-
for i in range(chains):
904-
if runset._retcode(i) != 0:
905-
msg = '{}, chain {} returned error code {}'.format(
906-
msg, i, runset._retcode(i)
907-
)
884+
msg = 'Error during generate_quantities.\n{}'.format(
885+
runset.get_err_msgs()
886+
)
908887
raise RuntimeError(msg)
909888
quantities = CmdStanGQ(runset=runset, mcmc_sample=sample_drawset)
910889
return quantities
@@ -1034,12 +1013,10 @@ def variational(
10341013
if not valid:
10351014
raise RuntimeError('The algorithm may not have converged.')
10361015
if not runset._check_retcodes():
1037-
msg = 'Error during variational inference'
1038-
if runset._retcode(dummy_chain_id) != 0:
1039-
msg = '{}, error code {}'.format(
1040-
msg, runset._retcode(dummy_chain_id)
1041-
)
1042-
raise RuntimeError(msg)
1016+
msg = 'Error during variational inference.\n{}'.format(
1017+
runset.get_err_msgs()
1018+
)
1019+
raise RuntimeError(msg)
10431020
# pylint: disable=invalid-name
10441021
vb = CmdStanVB(runset)
10451022
return vb

cmdstanpy/stanfit.py

Lines changed: 34 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -126,9 +126,25 @@ def __init__(
126126
def __repr__(self) -> str:
127127
repr = 'RunSet: chains={}'.format(self._chains)
128128
repr = '{}\n cmd:\n\t{}'.format(repr, self._cmds[0])
129-
repr = '{}\n csv_files:\n\t{}\n output_files:\n\t{}'.format(
130-
repr, '\n\t'.join(self._csv_files), '\n\t'.join(self._stdout_files)
131-
)
129+
repr = '{}\n retcodes={}'.format(repr, self._retcodes)
130+
if os.path.exists(self._csv_files[0]):
131+
repr = '{}\n csv_files:\n\t{}'.format(
132+
repr, '\n\t'.join(self._csv_files)
133+
)
134+
if self._args.save_diagnostics and os.path.exists(
135+
self._diagnostic_files[0]
136+
):
137+
repr = '{}\n diagnostics_files:\n\t{}'.format(
138+
repr, '\n\t'.join(self._diagnostic_files)
139+
)
140+
if os.path.exists(self._stdout_files[0]):
141+
repr = '{}\n console_msgs:\n\t{}'.format(
142+
repr, '\n\t'.join(self._stdout_files)
143+
)
144+
if os.path.exists(self._stderr_files[0]):
145+
repr = '{}\n error_msgs:\n\t{}'.format(
146+
repr, '\n\t'.join(self._stderr_files)
147+
)
132148
return repr
133149

134150
@property
@@ -199,27 +215,37 @@ def _set_retcode(self, idx: int, val: int) -> None:
199215
"""Set retcode for chain[idx] to val."""
200216
self._retcodes[idx] = val
201217

202-
def _get_err_msgs(self) -> List[str]:
218+
def get_err_msgs(self) -> List[str]:
203219
"""Checks console messages for each chain."""
204220
msgs = []
221+
msgs.append(self.__repr__())
205222
for i in range(self._chains):
206223
if (
207224
os.path.exists(self._stderr_files[i])
208225
and os.stat(self._stderr_files[i]).st_size > 0
209226
):
210227
with open(self._stderr_files[i], 'r') as fd:
211-
msgs.append('chain {}:\n{}\n'.format(i + 1, fd.read()))
228+
msgs.append(
229+
'chain_id {}:\n{}\n'.format(
230+
self._chain_ids[i], fd.read()
231+
)
232+
)
212233
if (
213234
os.path.exists(self._stdout_files[i])
214235
and os.stat(self._stdout_files[i]).st_size > 0
215236
):
216237
with open(self._stdout_files[i], 'r') as fd:
217238
contents = fd.read()
218-
pat = re.compile(r'^Exception.*$', re.M)
239+
# pattern matches initial "Exception" or "Error" msg
240+
pat = re.compile(r'^E[rx].*$', re.M)
219241
errors = re.findall(pat, contents)
220242
if len(errors) > 0:
221-
msgs.append('chain {}: {}\n'.format(i + 1, errors))
222-
return msgs
243+
msgs.append(
244+
'chain_id {}:\n\t{}\n'.format(
245+
self._chain_ids[i], '\n\t'.join(errors)
246+
)
247+
)
248+
return '\n'.join(msgs)
223249

224250
def save_csvfiles(self, dir: str = None) -> None:
225251
"""

test/test_optimize.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -124,9 +124,7 @@ def test_optimize_bad(self):
124124
)
125125
exp_bound_model = CmdStanModel(stan_file=stan)
126126
no_data = {}
127-
with self.assertRaisesRegex(
128-
Exception, 'Error during optimizing, error code 70'
129-
):
127+
with self.assertRaisesRegex(RuntimeError, 'Error during optimization'):
130128
exp_bound_model.optimize(
131129
data=no_data, seed=1239812093, inits=None, algorithm='BFGS'
132130
)

test/test_runset.py

Lines changed: 20 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313

1414

1515
class RunSetTest(unittest.TestCase):
16-
def test_check_retcodes(self):
16+
def test_check_repr(self):
1717
exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
1818
jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
1919
sampler_args = SamplerArgs()
@@ -26,8 +26,26 @@ def test_check_retcodes(self):
2626
method_args=sampler_args,
2727
)
2828
runset = RunSet(args=cmdstan_args)
29+
2930
self.assertIn('RunSet: chains=4', runset.__repr__())
3031
self.assertIn('method=sample', runset.__repr__())
32+
self.assertIn('retcodes=[-1, -1, -1, -1]', runset.__repr__())
33+
self.assertIn('csv_files:', runset.__repr__())
34+
self.assertNotIn('diagnostics_files:', runset.__repr__())
35+
36+
def test_check_retcodes(self):
37+
exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
38+
jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
39+
sampler_args = SamplerArgs()
40+
chain_ids = [1, 2, 3, 4] # default
41+
cmdstan_args = CmdStanArgs(
42+
model_name='bernoulli',
43+
model_exe=exe,
44+
chain_ids=chain_ids,
45+
data=jdata,
46+
method_args=sampler_args,
47+
)
48+
runset = RunSet(args=cmdstan_args)
3149

3250
retcodes = runset._retcodes
3351
self.assertEqual(4, len(retcodes))
@@ -60,7 +78,7 @@ def test_get_err_msgs(self):
6078
stdout_file = 'chain-' + str(i + 1) + '-missing-data-stdout.txt'
6179
path = os.path.join(DATAFILES_PATH, stdout_file)
6280
runset._stdout_files[i] = path
63-
errs = '\n\t'.join(runset._get_err_msgs())
81+
errs = runset.get_err_msgs()
6482
self.assertIn('Exception', errs)
6583

6684
def test_output_filenames(self):

test/test_sample.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -720,7 +720,7 @@ def test_validate_bad_run(self):
720720
DATAFILES_PATH, 'runset-bad', 'bad-transcript-bern-4.txt'
721721
),
722722
]
723-
self.assertEqual(len(runset._get_err_msgs()), 4)
723+
self.assertIn('Exception', runset.get_err_msgs())
724724

725725
# csv file headers inconsistent
726726
runset._csv_files = [

0 commit comments

Comments
 (0)