Skip to content

Commit cde2392

Browse files
committed
Merge branch 'develop' of https://github.com/stan-dev/cmdstanpy into develop
2 parents 5f4256e + df02c41 commit cde2392

4 files changed

Lines changed: 132 additions & 9 deletions

File tree

cmdstanpy/_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
"""PyPi Version"""
22

3-
__version__ = '0.9.62'
3+
__version__ = '0.9.63'

cmdstanpy/model.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,7 @@ def sample(
436436
output_dir: str = None,
437437
save_diagnostics: bool = False,
438438
show_progress: Union[bool, str] = False,
439+
validate_csv: bool = True,
439440
) -> CmdStanMCMC:
440441
"""
441442
Run or more chains of the NUTS sampler to produce a set of draws
@@ -580,6 +581,11 @@ def sample(
580581
If show_progress=='notebook' use tqdm_notebook
581582
(needs nodejs for jupyter).
582583
584+
:param validate_csv: If ``False``, skip scan of sample csv output file.
585+
When sample is large or disk i/o is slow, will speed up processing.
586+
Default is ``True`` - sample csv files are scanned for completeness
587+
and consistency.
588+
583589
:return: CmdStanMCMC object
584590
"""
585591
if chains is None:
@@ -620,7 +626,7 @@ def sample(
620626
if parallel_chains is None:
621627
parallel_chains = max(min(cpu_count(), chains), 1)
622628
elif parallel_chains > chains:
623-
self._logger.warning(
629+
self._logger.info(
624630
'Requesting %u parallel_chains for %u chains,'
625631
' running all chains in parallel.',
626632
parallel_chains,
@@ -756,7 +762,7 @@ def sample(
756762
err_msg = '{}{}'.format(err_msg, ''.join(console_errs))
757763
raise RuntimeError(err_msg)
758764

759-
mcmc = CmdStanMCMC(runset)
765+
mcmc = CmdStanMCMC(runset, validate_csv, logger=self._logger)
760766
return mcmc
761767

762768
def generate_quantities(

cmdstanpy/stanfit.py

Lines changed: 59 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import shutil
66
import copy
77
import logging
8+
import math
89
from typing import List, Tuple, Dict
910
from collections import Counter, OrderedDict
1011
from datetime import datetime
@@ -272,14 +273,21 @@ class CmdStanMCMC:
272273
Container for outputs from CmdStan sampler run.
273274
"""
274275

275-
def __init__(self, runset: RunSet) -> None:
276+
# pylint: disable=too-many-instance-attributes
277+
def __init__(
278+
self,
279+
runset: RunSet,
280+
validate_csv: bool = True,
281+
logger: logging.Logger = None,
282+
) -> None:
276283
"""Initialize object."""
277284
if not runset.method == Method.SAMPLE:
278285
raise ValueError(
279286
'Wrong runset method, expecting sample runset, '
280287
'found method {}'.format(runset.method)
281288
)
282289
self.runset = runset
290+
self._logger = logger or get_logger()
283291
# copy info from runset
284292
self._is_fixed_param = runset._args.method_args.fixed_param
285293
self._iter_sampling = runset._args.method_args.iter_sampling
@@ -298,7 +306,9 @@ def __init__(self, runset: RunSet) -> None:
298306
self._warmup = None
299307
self._drawset = None
300308
self._stan_variable_dims = {}
301-
self._validate_csv_files()
309+
self._validate_csv = validate_csv
310+
if validate_csv:
311+
self.validate_csv_files()
302312

303313
def __repr__(self) -> str:
304314
repr = 'CmdStanMCMC: model={} chains={}{}'.format(
@@ -326,11 +336,15 @@ def chain_ids(self) -> List[int]:
326336
@property
327337
def num_draws(self) -> int:
328338
"""Number of post-warmup draws per chain."""
339+
if not self._validate_csv and self._draws_sampling is None:
340+
return int(math.ceil(self._iter_sampling / self._thin))
329341
return self._draws_sampling
330342

331343
@property
332344
def num_draws_warmup(self) -> int:
333345
"""Number of warmup draws per chain."""
346+
if not self._validate_csv and self._draws_warmup is None:
347+
return int(math.ceil(self._iter_warmup / self._thin))
334348
return self._draws_warmup
335349

336350
@property
@@ -339,6 +353,12 @@ def column_names(self) -> Tuple[str, ...]:
339353
Names of all per-draw outputs: all
340354
sampler and model parameters and quantities of interest
341355
"""
356+
if not self._validate_csv and len(self._column_names) == 0:
357+
self._logger.warning(
358+
'csv files not yet validated, run method validate_csv_files()'
359+
' in order to retrieve sample metadata.'
360+
)
361+
return None
342362
return self._column_names
343363

344364
@property
@@ -348,6 +368,12 @@ def stan_variable_dims(self) -> Dict:
348368
Scalar types have int value '1'. Structured types have list of dims,
349369
e.g., program variable ``vector[10] foo`` has entry ``('foo', [10])``.
350370
"""
371+
if not self._validate_csv and len(self._stan_variable_dims) == 0:
372+
self._logger.warning(
373+
'csv files not yet validated, run method validate_csv_files()'
374+
' in order to retrieve sample metadata.'
375+
)
376+
return None
351377
return copy.deepcopy(self._stan_variable_dims)
352378

353379
@property
@@ -356,6 +382,14 @@ def metric_type(self) -> str:
356382
Metric type used for adaptation, either 'diag_e' or 'dense_e'.
357383
When sampler algorithm 'fixed_param' is specified, metric_type is None.
358384
"""
385+
if self._is_fixed_param:
386+
return None
387+
if not self._validate_csv and self._metric_type is None:
388+
self._logger.warning(
389+
'csv files not yet validated, run method validate_csv_files()'
390+
' in order to retrieve sample metadata.'
391+
)
392+
return None
359393
return self._metric_type
360394

361395
@property
@@ -364,7 +398,15 @@ def metric(self) -> np.ndarray:
364398
Metric used by sampler for each chain.
365399
When sampler algorithm 'fixed_param' is specified, metric is None.
366400
"""
367-
if not self._is_fixed_param and self._metric is None:
401+
if self._is_fixed_param:
402+
return None
403+
if not self._validate_csv and self._metric is None:
404+
self._logger.warning(
405+
'csv files not yet validated, run method validate_csv_files()'
406+
' in order to retrieve sample metadata.'
407+
)
408+
return None
409+
if self._sample is None:
368410
self._assemble_sample()
369411
return self._metric
370412

@@ -374,7 +416,15 @@ def stepsize(self) -> np.ndarray:
374416
Stepsize used by sampler for each chain.
375417
When sampler algorithm 'fixed_param' is specified, stepsize is None.
376418
"""
377-
if not self._is_fixed_param and self._stepsize is None:
419+
if self._is_fixed_param:
420+
return None
421+
if not self._validate_csv and self._stepsize is None:
422+
self._logger.warning(
423+
'csv files not yet validated, run method validate_csv_files()'
424+
' in order to retrieve sample metadata.'
425+
)
426+
return None
427+
if self._sample is None:
378428
self._assemble_sample()
379429
return self._stepsize
380430

@@ -386,6 +436,8 @@ def sample(self) -> np.ndarray:
386436
so that the values for each parameter are stored contiguously
387437
in memory, likewise all draws from a chain are contiguous.
388438
"""
439+
if not self._validate_csv and self._sample is None:
440+
self.validate_csv_files()
389441
if self._sample is None:
390442
self._assemble_sample()
391443
return self._sample
@@ -400,11 +452,13 @@ def warmup(self) -> np.ndarray:
400452
"""
401453
if not self._save_warmup:
402454
return None
455+
if not self._validate_csv and self._sample is None:
456+
self.validate_csv_files()
403457
if self._sample is None:
404458
self._assemble_sample()
405459
return self._warmup
406460

407-
def _validate_csv_files(self) -> None:
461+
def validate_csv_files(self) -> None:
408462
"""
409463
Checks that csv output files for all chains are consistent.
410464
Populates attributes for draws, column_names, num_params, metric_type.

test/test_sample.py

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ def test_multi_proc(self):
258258
log.check_present(
259259
(
260260
'cmdstanpy',
261-
'WARNING',
261+
'INFO',
262262
'Requesting 7 parallel_chains for 1 chains, '
263263
'running all chains in parallel.',
264264
)
@@ -292,6 +292,9 @@ def test_fixed_param_good(self):
292292
data=no_data, seed=12345, iter_sampling=100, fixed_param=True
293293
)
294294
self.assertEqual(datagen_fit.runset._args.method, Method.SAMPLE)
295+
self.assertEqual(datagen_fit.metric_type, None)
296+
self.assertEqual(datagen_fit.metric, None)
297+
self.assertEqual(datagen_fit.stepsize, None)
295298

296299
for i in range(datagen_fit.runset.chains):
297300
csv_file = datagen_fit.runset.csv_files[i]
@@ -851,6 +854,66 @@ def test_variables(self):
851854
self.assertTrue('theta' in vars)
852855
self.assertEqual(vars['theta'].shape, (20, 4))
853856

857+
def test_validate(self):
858+
stan = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
859+
jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
860+
bern_model = CmdStanModel(stan_file=stan)
861+
bern_fit = bern_model.sample(
862+
data=jdata,
863+
chains=2,
864+
seed=12345,
865+
iter_warmup=200,
866+
iter_sampling=100,
867+
thin=2,
868+
save_warmup=True,
869+
validate_csv=False,
870+
)
871+
# check error messages
872+
with LogCapture() as log:
873+
logging.getLogger()
874+
self.assertIsNone(bern_fit.column_names)
875+
expect = 'csv files not yet validated'
876+
msg = log.actual()[-1][-1]
877+
self.assertTrue(msg.startswith(expect))
878+
879+
with LogCapture() as log:
880+
logging.getLogger()
881+
self.assertIsNone(bern_fit.stan_variable_dims)
882+
expect = 'csv files not yet validated'
883+
msg = log.actual()[-1][-1]
884+
self.assertTrue(msg.startswith(expect))
885+
886+
with LogCapture() as log:
887+
logging.getLogger()
888+
self.assertIsNone(bern_fit.metric_type)
889+
expect = 'csv files not yet validated'
890+
msg = log.actual()[-1][-1]
891+
self.assertTrue(msg.startswith(expect))
892+
893+
with LogCapture() as log:
894+
logging.getLogger()
895+
self.assertIsNone(bern_fit.metric)
896+
expect = 'csv files not yet validated'
897+
msg = log.actual()[-1][-1]
898+
self.assertTrue(msg.startswith(expect))
899+
900+
with LogCapture() as log:
901+
logging.getLogger()
902+
self.assertIsNone(bern_fit.stepsize)
903+
expect = 'csv files not yet validated'
904+
msg = log.actual()[-1][-1]
905+
self.assertTrue(msg.startswith(expect))
906+
907+
# check computations match
908+
self.assertEqual(bern_fit.num_draws, 50)
909+
self.assertEqual(bern_fit.num_draws_warmup, 100)
910+
bern_fit.validate_csv_files()
911+
self.assertEqual(bern_fit.num_draws, 50)
912+
self.assertEqual(bern_fit.num_draws_warmup, 100)
913+
self.assertEqual(len(bern_fit.column_names), 8)
914+
self.assertEqual(len(bern_fit.stan_variable_dims), 1)
915+
self.assertEqual(bern_fit.metric_type, 'diag_e')
916+
854917

855918
if __name__ == '__main__':
856919
unittest.main()

0 commit comments

Comments
 (0)