55import shutil
66import copy
77import logging
8+ import math
89from typing import List , Tuple , Dict
910from collections import Counter , OrderedDict
1011from datetime import datetime
@@ -272,14 +273,21 @@ class CmdStanMCMC:
272273 Container for outputs from CmdStan sampler run.
273274 """
274275
275- def __init__ (self , runset : RunSet ) -> None :
276+ # pylint: disable=too-many-instance-attributes
277+ def __init__ (
278+ self ,
279+ runset : RunSet ,
280+ validate_csv : bool = True ,
281+ logger : logging .Logger = None ,
282+ ) -> None :
276283 """Initialize object."""
277284 if not runset .method == Method .SAMPLE :
278285 raise ValueError (
279286 'Wrong runset method, expecting sample runset, '
280287 'found method {}' .format (runset .method )
281288 )
282289 self .runset = runset
290+ self ._logger = logger or get_logger ()
283291 # copy info from runset
284292 self ._is_fixed_param = runset ._args .method_args .fixed_param
285293 self ._iter_sampling = runset ._args .method_args .iter_sampling
@@ -298,7 +306,9 @@ def __init__(self, runset: RunSet) -> None:
298306 self ._warmup = None
299307 self ._drawset = None
300308 self ._stan_variable_dims = {}
301- self ._validate_csv_files ()
309+ self ._validate_csv = validate_csv
310+ if validate_csv :
311+ self .validate_csv_files ()
302312
303313 def __repr__ (self ) -> str :
304314 repr = 'CmdStanMCMC: model={} chains={}{}' .format (
@@ -326,11 +336,15 @@ def chain_ids(self) -> List[int]:
@property
def num_draws(self) -> int:
    """
    Post-warmup draw count for a single chain.

    If csv validation was skipped and no draw count has been recorded,
    the count is computed from the sampler arguments as
    ``ceil(iter_sampling / thin)``; otherwise the recorded value is
    returned unchanged.
    """
    if self._validate_csv or self._draws_sampling is not None:
        return self._draws_sampling
    # Nothing recorded yet: fall back to the thinned iteration count.
    thinned = self._iter_sampling / self._thin
    return int(math.ceil(thinned))
330342
@property
def num_draws_warmup(self) -> int:
    """
    Warmup draw count for a single chain.

    If csv validation was skipped and no warmup count has been recorded,
    the count is computed from the sampler arguments as
    ``ceil(iter_warmup / thin)``; otherwise the recorded value is
    returned unchanged.
    """
    if self._validate_csv or self._draws_warmup is not None:
        return self._draws_warmup
    # Nothing recorded yet: fall back to the thinned iteration count.
    thinned = self._iter_warmup / self._thin
    return int(math.ceil(thinned))
335349
@property
def column_names(self) -> Tuple[str, ...]:
    """
    Names of all per-draw outputs: all sampler and model parameters
    and quantities of interest.

    Returns None, after logging a warning, when csv validation was
    skipped and no column names are stored.
    """
    if self._validate_csv or len(self._column_names) != 0:
        return self._column_names
    # Metadata is unavailable until the csv files have been validated.
    self._logger.warning(
        'csv files not yet validated, run method validate_csv_files()'
        ' in order to retrieve sample metadata.'
    )
    return None
343363
@property
def stan_variable_dims(self) -> Dict:
    """
    Deep copy of the stored Stan variable dimensions.
    Scalar types have int value '1'. Structured types have list of dims,
    e.g., program variable ``vector[10] foo`` has entry ``('foo', [10])``.

    Returns None, after logging a warning, when csv validation was
    skipped and no dimensions are stored.
    """
    if self._validate_csv or len(self._stan_variable_dims) != 0:
        # Return a copy so callers cannot mutate the stored dict.
        return copy.deepcopy(self._stan_variable_dims)
    self._logger.warning(
        'csv files not yet validated, run method validate_csv_files()'
        ' in order to retrieve sample metadata.'
    )
    return None
352378
@property
def metric_type(self) -> str:
    """
    Metric type used for adaptation, either 'diag_e' or 'dense_e'.
    When sampler algorithm 'fixed_param' is specified, metric_type is None.

    Also returns None, after logging a warning, when csv validation was
    skipped and no metric type is stored.
    """
    if not self._is_fixed_param:
        if self._validate_csv or self._metric_type is not None:
            return self._metric_type
        # Metadata is unavailable until the csv files have been validated.
        self._logger.warning(
            'csv files not yet validated, run method validate_csv_files()'
            ' in order to retrieve sample metadata.'
        )
    return None
360394
@property
def metric(self) -> np.ndarray:
    """
    Metric used by sampler for each chain.
    When sampler algorithm 'fixed_param' is specified, metric is None.

    Also returns None, after logging a warning, when csv validation was
    skipped and no metric is stored. Otherwise the sample is assembled
    first if it has not been assembled yet.
    """
    if self._is_fixed_param:
        return None
    metadata_missing = not self._validate_csv and self._metric is None
    if metadata_missing:
        self._logger.warning(
            'csv files not yet validated, run method validate_csv_files()'
            ' in order to retrieve sample metadata.'
        )
        return None
    sample_pending = self._sample is None
    if sample_pending:
        self._assemble_sample()
    return self._metric
370412
@@ -374,7 +416,15 @@ def stepsize(self) -> np.ndarray:
374416 Stepsize used by sampler for each chain.
375417 When sampler algorithm 'fixed_param' is specified, stepsize is None.
376418 """
377- if not self ._is_fixed_param and self ._stepsize is None :
419+ if self ._is_fixed_param :
420+ return None
421+ if not self ._validate_csv and self ._stepsize is None :
422+ self ._logger .warning (
423+ 'csv files not yet validated, run method validate_csv_files()'
424+ ' in order to retrieve sample metadata.'
425+ )
426+ return None
427+ if self ._sample is None :
378428 self ._assemble_sample ()
379429 return self ._stepsize
380430
@@ -386,6 +436,8 @@ def sample(self) -> np.ndarray:
386436 so that the values for each parameter are stored contiguously
387437 in memory, likewise all draws from a chain are contiguous.
388438 """
439+ if not self ._validate_csv and self ._sample is None :
440+ self .validate_csv_files ()
389441 if self ._sample is None :
390442 self ._assemble_sample ()
391443 return self ._sample
@@ -400,11 +452,13 @@ def warmup(self) -> np.ndarray:
400452 """
401453 if not self ._save_warmup :
402454 return None
455+ if not self ._validate_csv and self ._sample is None :
456+ self .validate_csv_files ()
403457 if self ._sample is None :
404458 self ._assemble_sample ()
405459 return self ._warmup
406460
407- def _validate_csv_files (self ) -> None :
461+ def validate_csv_files (self ) -> None :
408462 """
409463 Checks that csv output files for all chains are consistent.
410464 Populates attributes for draws, column_names, num_params, metric_type.
0 commit comments