|
5 | 5 | import logging |
6 | 6 | import shutil |
7 | 7 | from multiprocessing import cpu_count |
| 8 | +import tempfile |
| 9 | +import stat |
8 | 10 | import unittest |
9 | 11 | from time import time |
10 | 12 | from testfixtures import LogCapture |
11 | 13 | import pytest |
12 | 14 |
|
| 15 | +from cmdstanpy import _TMPDIR |
13 | 16 | from cmdstanpy.cmdstan_args import Method, SamplerArgs, CmdStanArgs |
14 | 17 | from cmdstanpy.utils import EXTENSION |
15 | 18 | from cmdstanpy.stanfit import RunSet, CmdStanMCMC |
@@ -76,6 +79,36 @@ def test_bernoulli_good(self, stanfile='bernoulli.stan'): |
76 | 79 | self.assertEqual(bern_fit.stepsize.shape, (2,)) |
77 | 80 | self.assertEqual(bern_fit.metric.shape, (2, 1)) |
78 | 81 |
|
| 82 | + bern_fit = bern_model.sample( |
| 83 | + data=jdata, |
| 84 | + chains=2, |
| 85 | + parallel_chains=2, |
| 86 | + seed=12345, |
| 87 | + iter_warmup=1000, |
| 88 | + iter_sampling=100, |
| 89 | + metric='dense_e', |
| 90 | + ) |
| 91 | + self.assertIn('CmdStanMCMC: model=bernoulli', bern_fit.__repr__()) |
| 92 | + self.assertIn('method=sample', bern_fit.__repr__()) |
| 93 | + |
| 94 | + self.assertEqual(bern_fit.runset._args.method, Method.SAMPLE) |
| 95 | + |
| 96 | + for i in range(bern_fit.runset.chains): |
| 97 | + csv_file = bern_fit.runset.csv_files[i] |
| 98 | + stdout_file = bern_fit.runset.stdout_files[i] |
| 99 | + self.assertTrue(os.path.exists(csv_file)) |
| 100 | + self.assertTrue(os.path.exists(stdout_file)) |
| 101 | + |
| 102 | + self.assertEqual(bern_fit.runset.chains, 2) |
| 103 | + self.assertEqual(bern_fit.num_draws, 100) |
| 104 | + self.assertEqual(bern_fit.column_names, tuple(BERNOULLI_COLS)) |
| 105 | + |
| 106 | + bern_sample = bern_fit.sample |
| 107 | + self.assertEqual(bern_sample.shape, (100, 2, len(BERNOULLI_COLS))) |
| 108 | + self.assertEqual(bern_fit.metric_type, 'dense_e') |
| 109 | + self.assertEqual(bern_fit.stepsize.shape, (2,)) |
| 110 | + self.assertEqual(bern_fit.metric.shape, (2, 1, 1)) |
| 111 | + |
79 | 112 | bern_fit = bern_model.sample( |
80 | 113 | data=jdata, |
81 | 114 | chains=2, |
@@ -510,9 +543,9 @@ def test_validate_big_run(self): |
510 | 543 | self.assertEqual((2000, 3), mo_phis.shape) |
511 | 544 | phi2095 = fit.get_drawset(params=['phi.2095']) |
512 | 545 | self.assertEqual((2000, 1), phi2095.shape) |
513 | | - with self.assertRaises(Exception): |
| 546 | + with self.assertRaisesRegex(ValueError, 'unknown parameter: phi.2096'): |
514 | 547 | fit.get_drawset(params=['phi.2096']) |
515 | | - with self.assertRaises(Exception): |
| 548 | + with self.assertRaisesRegex(ValueError, 'unknown parameter: ph'): |
516 | 549 | fit.get_drawset(params=['ph']) |
517 | 550 |
|
518 | 551 | # pylint: disable=no-self-use |
@@ -575,7 +608,9 @@ def test_save_csv(self): |
575 | 608 | for i in range(bern_fit.runset.chains): |
576 | 609 | csv_file = bern_fit.runset.csv_files[i] |
577 | 610 | self.assertTrue(os.path.exists(csv_file)) |
578 | | - with self.assertRaisesRegex(Exception, 'file exists'): |
| 611 | + with self.assertRaisesRegex( |
| 612 | + ValueError, 'file exists, not overwriting: ' |
| 613 | + ): |
579 | 614 | bern_fit.save_csvfiles(dir=DATAFILES_PATH) |
580 | 615 |
|
581 | 616 | tmp2_dir = os.path.join(HERE, 'tmp2') |
@@ -611,6 +646,15 @@ def test_save_csv(self): |
611 | 646 | if os.path.exists(bern_fit.runset.stderr_files[i]): |
612 | 647 | os.remove(bern_fit.runset.stderr_files[i]) |
613 | 648 |
|
| 649 | + with self.assertRaisesRegex(ValueError, 'cannot access csv file'): |
| 650 | + bern_fit.save_csvfiles(dir=DATAFILES_PATH) |
| 651 | + |
| 652 | + if platform.system() != "Windows": |
| 653 | + with self.assertRaisesRegex(Exception, 'cannot save to path: '): |
| 654 | + dir = tempfile.mkdtemp(dir=_TMPDIR) |
| 655 | + os.chmod(dir, stat.S_IRUSR | stat.S_IRGRP | stat.S_IROTH) |
| 656 | + bern_fit.save_csvfiles(dir=dir) |
| 657 | + |
614 | 658 | def test_diagnose_divergences(self): |
615 | 659 | exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION) |
616 | 660 | sampler_args = SamplerArgs() |
|
0 commit comments