Skip to content

Commit 5f4256e

Browse files
committed
Merge branch 'master' of https://github.com/stan-dev/cmdstanpy
2 parents ac80e60 + 2aaf60f commit 5f4256e

5 files changed

Lines changed: 150 additions & 43 deletions

File tree

cmdstanpy/model.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -470,7 +470,7 @@ def sample(
470470
:param parallel_chains: Number of processes to run in parallel. Must be
471471
a positive integer. Defaults to ``multiprocessing.cpu_count()``.
472472
473-
:param threads_per_chain: the number of threads to use in parallelized
473+
:param threads_per_chain: The number of threads to use in parallelized
474474
sections within an MCMC chain (e.g., when using the Stan functions
475475
``reduce_sum()`` or ``map_rect()``). This will only have an effect
476476
if the model was compiled with threading support. The total number
@@ -597,12 +597,12 @@ def sample(
597597
chain_ids = [x + 1 for x in range(chains)]
598598
else:
599599
if isinstance(chain_ids, int):
600-
if chain_ids < 0:
600+
if chain_ids < 1:
601601
raise ValueError(
602-
'Chain_id must be a non-negative integer value,'
602+
'Chain_id must be a positive integer value,'
603603
' found {}.'.format(chain_ids)
604604
)
605-
chain_ids = [chain_ids + i + 1 for i in range(chains)]
605+
chain_ids = [chain_ids + i for i in range(chains)]
606606
else:
607607
if not len(chain_ids) == chains:
608608
raise ValueError(
@@ -691,7 +691,7 @@ def sample(
691691
refresh=refresh,
692692
logger=self._logger,
693693
)
694-
runset = RunSet(args=args, chains=chains)
694+
runset = RunSet(args=args, chains=chains, chain_ids=chain_ids)
695695
pbar = None
696696
all_pbars = []
697697

@@ -818,16 +818,19 @@ def generate_quantities(
818818
sample_csv_files = mcmc_sample.runset.csv_files
819819
sample_drawset = mcmc_sample.get_drawset()
820820
chains = mcmc_sample.chains
821+
chain_ids = mcmc_sample.chain_ids
821822
elif isinstance(mcmc_sample, list):
823+
if len(mcmc_sample) < 1:
824+
raise ValueError('MCMC sample cannot be empty list')
822825
sample_csv_files = mcmc_sample
826+
chains = len(sample_csv_files)
827+
chain_ids = [x + 1 for x in range(chains)]
823828
else:
824829
raise ValueError(
825830
'MCMC sample must be either CmdStanMCMC object'
826831
' or list of paths to sample csv_files.'
827832
)
828-
829833
try:
830-
chains = len(sample_csv_files)
831834
if sample_drawset is None: # assemble sample from csv files
832835
config = {}
833836
# scan 1st csv file to get config
@@ -852,10 +855,10 @@ def generate_quantities(
852855
args = CmdStanArgs(
853856
self._name,
854857
self._exe_file,
855-
chain_ids=[x + 1 for x in range(chains)],
858+
chain_ids=chain_ids,
856859
method_args=sampler_args,
857860
)
858-
runset = RunSet(args=args, chains=chains)
861+
runset = RunSet(args=args, chains=chains, chain_ids=chain_ids)
859862
runset._csv_files = sample_csv_files
860863
sample_fit = CmdStanMCMC(runset)
861864
sample_drawset = sample_fit.get_drawset()
@@ -875,13 +878,13 @@ def generate_quantities(
875878
args = CmdStanArgs(
876879
self._name,
877880
self._exe_file,
878-
chain_ids=[x + 1 for x in range(chains)],
881+
chain_ids=chain_ids,
879882
data=_data,
880883
seed=seed,
881884
output_dir=gq_output_dir,
882885
method_args=generate_quantities_args,
883886
)
884-
runset = RunSet(args=args, chains=chains)
887+
runset = RunSet(args=args, chains=chains, chain_ids=chain_ids)
885888

886889
parallel_chains_avail = cpu_count()
887890
parallel_chains = max(min(parallel_chains_avail - 2, chains), 1)

cmdstanpy/stanfit.py

Lines changed: 57 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,11 @@ class RunSet:
3434
"""
3535

3636
def __init__(
37-
self, args: CmdStanArgs, chains: int = 4, logger: logging.Logger = None
37+
self,
38+
args: CmdStanArgs,
39+
chains: int = 4,
40+
chain_ids: List[int] = None,
41+
logger: logging.Logger = None,
3842
) -> None:
3943
"""Initialize object."""
4044
self._args = args
@@ -45,7 +49,16 @@ def __init__(
4549
'chains must be positive integer value, '
4650
'found {}'.format(chains)
4751
)
48-
52+
if chain_ids is None:
53+
chain_ids = [x + 1 for x in range(chains)]
54+
elif len(chain_ids) != chains:
55+
raise ValueError(
56+
'mismatch between number of chains and chain_ids, '
57+
'found {} chains, but {} chain_ids'.format(
58+
chains, len(chain_ids)
59+
)
60+
)
61+
self._chain_ids = chain_ids
4962
self._retcodes = [-1 for _ in range(chains)]
5063

5164
# stdout, stderr are written to text files
@@ -67,12 +80,13 @@ def __init__(
6780
if args.output_dir is None:
6881
csv_file = create_named_text_file(
6982
dir=output_dir,
70-
prefix='{}-{}-'.format(file_basename, i + 1),
83+
prefix='{}-{}-'.format(file_basename, str(chain_ids[i])),
7184
suffix='.csv',
7285
)
7386
else:
7487
csv_file = os.path.join(
75-
output_dir, '{}-{}.{}'.format(file_basename, i + 1, 'csv')
88+
output_dir,
89+
'{}-{}.{}'.format(file_basename, str(chain_ids[i]), 'csv'),
7690
)
7791
self._csv_files[i] = csv_file
7892
stdout_file = ''.join(
@@ -87,14 +101,16 @@ def __init__(
87101
if args.output_dir is None:
88102
diag_file = create_named_text_file(
89103
dir=_TMPDIR,
90-
prefix='{}-diagnostic-{}-'.format(file_basename, i + 1),
104+
prefix='{}-diagnostic-{}-'.format(
105+
file_basename, str(chain_ids[i])
106+
),
91107
suffix='.csv',
92108
)
93109
else:
94110
diag_file = os.path.join(
95111
output_dir,
96112
'{}-diagnostic-{}.{}'.format(
97-
file_basename, i + 1, 'csv'
113+
file_basename, str(chain_ids[i]), 'csv'
98114
),
99115
)
100116
self._diagnostic_files[i] = diag_file
@@ -126,9 +142,14 @@ def method(self) -> Method:
126142

127143
@property
128144
def chains(self) -> int:
129-
"""Number of sampler chains."""
145+
"""Number of chains."""
130146
return self._chains
131147

148+
@property
149+
def chain_ids(self) -> List[int]:
150+
"""Chain ids."""
151+
return self._chain_ids
152+
132153
@property
133154
def cmds(self) -> List[str]:
134155
"""Per-chain call to CmdStan."""
@@ -297,6 +318,11 @@ def chains(self) -> int:
297318
"""Number of chains."""
298319
return self.runset.chains
299320

321+
@property
322+
def chain_ids(self) -> List[int]:
323+
"""Chain ids."""
324+
return self.runset.chain_ids
325+
300326
@property
301327
def num_draws(self) -> int:
302328
"""Number of post-warmup draws per chain."""
@@ -499,12 +525,34 @@ def _assemble_sample(self) -> None:
499525
xs = line.split(',')
500526
self._sample[i, chain, :] = [float(x) for x in xs]
501527

502-
def summary(self) -> pd.DataFrame:
528+
def summary(self, percentiles: List[int] = None) -> pd.DataFrame:
503529
"""
504530
Run cmdstan/bin/stansummary over all output csv files.
505531
Echo stansummary stdout/stderr to console.
506532
Assemble csv tempfile contents into pandasDataFrame.
533+
534+
:param percentiles: Ordered non-empty list of percentiles to report.
535+
Must be integers from (1, 99), inclusive.
507536
"""
537+
percentiles_str = '--percentiles=5,50,95'
538+
if percentiles is not None:
539+
if len(percentiles) == 0:
540+
raise ValueError(
541+
'invalid percentiles argument, must be ordered'
542+
' non-empty list from (1, 99), inclusive.'
543+
)
544+
545+
cur_pct = 0
546+
for pct in percentiles:
547+
if pct > 99 or not pct > cur_pct:
548+
raise ValueError(
549+
'invalid percentiles spec, must be ordered'
550+
' non-empty list from (1, 99), inclusive.'
551+
)
552+
cur_pct = pct
553+
percentiles_str = '='.join(
554+
['--percentiles', ','.join([str(x) for x in percentiles])]
555+
)
508556
cmd_path = os.path.join(
509557
cmdstan_path(), 'bin', 'stansummary' + EXTENSION
510558
)
@@ -516,6 +564,7 @@ def summary(self) -> pd.DataFrame:
516564
)
517565
cmd = [
518566
cmd_path,
567+
percentiles_str,
519568
'--csv_file={}'.format(tmp_csv_path),
520569
] + self.runset.csv_files
521570
do_command(cmd, logger=self.runset._logger)

test/test_generate_quantities.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ def test_gen_quantities_csv_files_bad(self):
6060
model = CmdStanModel(stan_file=stan)
6161
jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
6262

63+
with self.assertRaises(ValueError):
64+
model.generate_quantities(data=jdata, mcmc_sample=[])
65+
6366
# synthesize list of filenames
6467
goodfiles_path = os.path.join(
6568
DATAFILES_PATH, 'runset-bad', 'bad-draws-bern'

test/test_runset.py

Lines changed: 47 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,15 @@ def test_check_retcodes(self):
1616
exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
1717
jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
1818
sampler_args = SamplerArgs()
19+
chain_ids = [1, 2, 3, 4] # default
1920
cmdstan_args = CmdStanArgs(
2021
model_name='bernoulli',
2122
model_exe=exe,
22-
chain_ids=[1, 2, 3, 4],
23+
chain_ids=chain_ids,
2324
data=jdata,
2425
method_args=sampler_args,
2526
)
26-
runset = RunSet(args=cmdstan_args, chains=4)
27+
runset = RunSet(args=cmdstan_args)
2728
self.assertIn('RunSet: chains=4', runset.__repr__())
2829
self.assertIn('method=sample', runset.__repr__())
2930

@@ -44,14 +45,15 @@ def test_get_err_msgs(self):
4445
exe = os.path.join(DATAFILES_PATH, 'logistic' + EXTENSION)
4546
rdata = os.path.join(DATAFILES_PATH, 'logistic.data.R')
4647
sampler_args = SamplerArgs()
48+
chain_ids = [1, 2, 3]
4749
cmdstan_args = CmdStanArgs(
4850
model_name='logistic',
4951
model_exe=exe,
50-
chain_ids=[1, 2, 3],
52+
chain_ids=chain_ids,
5153
data=rdata,
5254
method_args=sampler_args,
5355
)
54-
runset = RunSet(args=cmdstan_args, chains=3)
56+
runset = RunSet(args=cmdstan_args, chains=3, chain_ids=chain_ids)
5557
for i in range(3):
5658
runset._set_retcode(i, 70)
5759
stdout_file = 'chain-' + str(i + 1) + '-missing-data-stdout.txt'
@@ -64,14 +66,15 @@ def test_output_filenames(self):
6466
exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
6567
jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
6668
sampler_args = SamplerArgs()
69+
chain_ids = [1, 2, 3, 4]
6770
cmdstan_args = CmdStanArgs(
6871
model_name='bernoulli',
6972
model_exe=exe,
70-
chain_ids=[1, 2, 3, 4],
73+
chain_ids=chain_ids,
7174
data=jdata,
7275
method_args=sampler_args,
7376
)
74-
runset = RunSet(args=cmdstan_args, chains=4)
77+
runset = RunSet(args=cmdstan_args)
7578
self.assertIn('bernoulli-', runset._csv_files[0])
7679
self.assertIn('-1-', runset._csv_files[0])
7780
self.assertIn('-4-', runset._csv_files[3])
@@ -80,17 +83,53 @@ def test_commands(self):
8083
exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
8184
jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
8285
sampler_args = SamplerArgs()
86+
chain_ids = [1, 2, 3, 4]
8387
cmdstan_args = CmdStanArgs(
8488
model_name='bernoulli',
8589
model_exe=exe,
86-
chain_ids=[1, 2, 3, 4],
90+
chain_ids=chain_ids,
8791
data=jdata,
8892
method_args=sampler_args,
8993
)
90-
runset = RunSet(args=cmdstan_args, chains=4)
94+
runset = RunSet(args=cmdstan_args)
9195
self.assertIn('id=1', runset._cmds[0])
9296
self.assertIn('id=4', runset._cmds[3])
9397

98+
def test_chain_ids(self):
99+
exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
100+
jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
101+
sampler_args = SamplerArgs()
102+
chain_ids = [11, 12, 13, 14]
103+
cmdstan_args = CmdStanArgs(
104+
model_name='bernoulli',
105+
model_exe=exe,
106+
chain_ids=chain_ids,
107+
data=jdata,
108+
method_args=sampler_args,
109+
)
110+
runset = RunSet(args=cmdstan_args, chains=4, chain_ids=chain_ids)
111+
self.assertIn('id=11', runset._cmds[0])
112+
self.assertIn('-11-', runset._csv_files[0])
113+
self.assertIn('id=14', runset._cmds[3])
114+
self.assertIn('-14-', runset._csv_files[3])
115+
116+
def test_ctor_checks(self):
117+
exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
118+
jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
119+
sampler_args = SamplerArgs()
120+
chain_ids = [11, 12, 13, 14]
121+
cmdstan_args = CmdStanArgs(
122+
model_name='bernoulli',
123+
model_exe=exe,
124+
chain_ids=chain_ids,
125+
data=jdata,
126+
method_args=sampler_args,
127+
)
128+
with self.assertRaises(ValueError):
129+
RunSet(args=cmdstan_args, chains=0)
130+
with self.assertRaises(ValueError):
131+
RunSet(args=cmdstan_args, chains=4, chain_ids=[1, 2, 3])
132+
94133

95134
if __name__ == '__main__':
96135
unittest.main()

0 commit comments

Comments
 (0)