Skip to content

Commit 981d7f3

Browse files
committed
Merge branch 'develop' of https://github.com/stan-dev/cmdstanpy into develop
2 parents cde2392 + d758170 commit 981d7f3

5 files changed

Lines changed: 80 additions & 7 deletions

File tree

cmdstanpy/install_cmdstan.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def validate_dir(install_dir):
163163
if not os.path.exists(install_dir):
164164
try:
165165
os.makedirs(install_dir)
166-
except OSError as e:
166+
except (IOError, OSError, PermissionError) as e:
167167
raise ValueError(
168168
'Cannot create directory: {}'.format(install_dir)
169169
) from e
@@ -176,7 +176,7 @@ def validate_dir(install_dir):
176176
with open('tmp_test_w', 'w'):
177177
pass
178178
os.remove('tmp_test_w') # cleanup
179-
except OSError as e:
179+
except (IOError, OSError, PermissionError) as e:
180180
raise ValueError(
181181
'Cannot write files to directory {}'.format(install_dir)
182182
) from e

cmdstanpy/install_cxx_toolchain.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ def validate_dir(install_dir):
211211
if not os.path.exists(install_dir):
212212
try:
213213
os.makedirs(install_dir)
214-
except OSError as e:
214+
except (IOError, OSError, PermissionError) as e:
215215
raise ValueError(
216216
'Cannot create directory: {}'.format(install_dir)
217217
) from e

cmdstanpy/stanfit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ def save_csvfiles(self, dir: str = None) -> None:
235235
with open(test_path, 'w'):
236236
pass
237237
os.remove(test_path) # cleanup
238-
except OSError:
238+
except (IOError, OSError, PermissionError):
239239
raise Exception('cannot save to path: {}'.format(dir))
240240

241241
for i in range(self.chains):

test/test_runset.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import os
44
import unittest
55

6+
from cmdstanpy import _TMPDIR
67
from cmdstanpy.cmdstan_args import SamplerArgs, CmdStanArgs
78
from cmdstanpy.utils import EXTENSION
89
from cmdstanpy.stanfit import RunSet
@@ -95,6 +96,34 @@ def test_commands(self):
9596
self.assertIn('id=1', runset._cmds[0])
9697
self.assertIn('id=4', runset._cmds[3])
9798

99+
def test_save_diagnostics(self):
100+
exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
101+
jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
102+
sampler_args = SamplerArgs()
103+
chain_ids = [1, 2, 3, 4]
104+
cmdstan_args = CmdStanArgs(
105+
model_name='bernoulli',
106+
model_exe=exe,
107+
chain_ids=chain_ids,
108+
data=jdata,
109+
method_args=sampler_args,
110+
save_diagnostics=True,
111+
)
112+
runset = RunSet(args=cmdstan_args)
113+
self.assertIn(_TMPDIR, runset.diagnostic_files[0])
114+
115+
cmdstan_args = CmdStanArgs(
116+
model_name='bernoulli',
117+
model_exe=exe,
118+
chain_ids=chain_ids,
119+
data=jdata,
120+
method_args=sampler_args,
121+
save_diagnostics=True,
122+
output_dir=os.path.abspath('.'),
123+
)
124+
runset = RunSet(args=cmdstan_args)
125+
self.assertIn(os.path.abspath('.'), runset.diagnostic_files[0])
126+
98127
def test_chain_ids(self):
99128
exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
100129
jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')

test/test_sample.py

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,14 @@
55
import logging
66
import shutil
77
from multiprocessing import cpu_count
8+
import tempfile
9+
import stat
810
import unittest
911
from time import time
1012
from testfixtures import LogCapture
1113
import pytest
1214

15+
from cmdstanpy import _TMPDIR
1316
from cmdstanpy.cmdstan_args import Method, SamplerArgs, CmdStanArgs
1417
from cmdstanpy.utils import EXTENSION
1518
from cmdstanpy.stanfit import RunSet, CmdStanMCMC
@@ -76,6 +79,36 @@ def test_bernoulli_good(self, stanfile='bernoulli.stan'):
7679
self.assertEqual(bern_fit.stepsize.shape, (2,))
7780
self.assertEqual(bern_fit.metric.shape, (2, 1))
7881

82+
bern_fit = bern_model.sample(
83+
data=jdata,
84+
chains=2,
85+
parallel_chains=2,
86+
seed=12345,
87+
iter_warmup=1000,
88+
iter_sampling=100,
89+
metric='dense_e',
90+
)
91+
self.assertIn('CmdStanMCMC: model=bernoulli', bern_fit.__repr__())
92+
self.assertIn('method=sample', bern_fit.__repr__())
93+
94+
self.assertEqual(bern_fit.runset._args.method, Method.SAMPLE)
95+
96+
for i in range(bern_fit.runset.chains):
97+
csv_file = bern_fit.runset.csv_files[i]
98+
stdout_file = bern_fit.runset.stdout_files[i]
99+
self.assertTrue(os.path.exists(csv_file))
100+
self.assertTrue(os.path.exists(stdout_file))
101+
102+
self.assertEqual(bern_fit.runset.chains, 2)
103+
self.assertEqual(bern_fit.num_draws, 100)
104+
self.assertEqual(bern_fit.column_names, tuple(BERNOULLI_COLS))
105+
106+
bern_sample = bern_fit.sample
107+
self.assertEqual(bern_sample.shape, (100, 2, len(BERNOULLI_COLS)))
108+
self.assertEqual(bern_fit.metric_type, 'dense_e')
109+
self.assertEqual(bern_fit.stepsize.shape, (2,))
110+
self.assertEqual(bern_fit.metric.shape, (2, 1, 1))
111+
79112
bern_fit = bern_model.sample(
80113
data=jdata,
81114
chains=2,
@@ -510,9 +543,9 @@ def test_validate_big_run(self):
510543
self.assertEqual((2000, 3), mo_phis.shape)
511544
phi2095 = fit.get_drawset(params=['phi.2095'])
512545
self.assertEqual((2000, 1), phi2095.shape)
513-
with self.assertRaises(Exception):
546+
with self.assertRaisesRegex(ValueError, 'unknown parameter: phi.2096'):
514547
fit.get_drawset(params=['phi.2096'])
515-
with self.assertRaises(Exception):
548+
with self.assertRaisesRegex(ValueError, 'unknown parameter: ph'):
516549
fit.get_drawset(params=['ph'])
517550

518551
# pylint: disable=no-self-use
@@ -575,7 +608,9 @@ def test_save_csv(self):
575608
for i in range(bern_fit.runset.chains):
576609
csv_file = bern_fit.runset.csv_files[i]
577610
self.assertTrue(os.path.exists(csv_file))
578-
with self.assertRaisesRegex(Exception, 'file exists'):
611+
with self.assertRaisesRegex(
612+
ValueError, 'file exists, not overwriting: '
613+
):
579614
bern_fit.save_csvfiles(dir=DATAFILES_PATH)
580615

581616
tmp2_dir = os.path.join(HERE, 'tmp2')
@@ -611,6 +646,15 @@ def test_save_csv(self):
611646
if os.path.exists(bern_fit.runset.stderr_files[i]):
612647
os.remove(bern_fit.runset.stderr_files[i])
613648

649+
with self.assertRaisesRegex(ValueError, 'cannot access csv file'):
650+
bern_fit.save_csvfiles(dir=DATAFILES_PATH)
651+
652+
if platform.system() != "Windows":
653+
with self.assertRaisesRegex(Exception, 'cannot save to path: '):
654+
dir = tempfile.mkdtemp(dir=_TMPDIR)
655+
os.chmod(dir, stat.S_IRUSR | stat.S_IRGRP | stat.S_IROTH)
656+
bern_fit.save_csvfiles(dir=dir)
657+
614658
def test_diagnose_divergences(self):
615659
exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
616660
sampler_args = SamplerArgs()

0 commit comments

Comments
 (0)