Skip to content

Commit 318e6e0

Browse files
authored
Merge pull request #283 from b-r-oleary/bro-fix-variational-args
Some fixes/additions to the variational arguments
2 parents baad6ce + 1950b27 commit 318e6e0

5 files changed

Lines changed: 54 additions & 26 deletions

File tree

cmdstanpy/cmdstan_args.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def validate(self, chains: int) -> None:
132132
)
133133
if self.step_size is not None:
134134
if isinstance(self.step_size, Real):
135-
if self.step_size < 0:
135+
if self.step_size <= 0:
136136
raise ValueError(
137137
'step_size must be > 0, found {}'.format(self.step_size)
138138
)
@@ -336,7 +336,7 @@ def validate(self, chains=None) -> None: # pylint: disable=unused-argument
336336
'init_alpha must not be set when algorithm is Newton'
337337
)
338338
if isinstance(self.init_alpha, Real):
339-
if self.init_alpha < 0:
339+
if self.init_alpha <= 0:
340340
raise ValueError('init_alpha must be greater than 0')
341341
else:
342342
raise ValueError('init_alpha must be type of float')
@@ -403,6 +403,7 @@ def __init__(
403403
elbo_samples: int = None,
404404
eta: Real = None,
405405
adapt_iter: int = None,
406+
adapt_engaged: bool = True,
406407
tol_rel_obj: Real = None,
407408
eval_elbo: int = None,
408409
output_samples: int = None,
@@ -413,6 +414,7 @@ def __init__(
413414
self.elbo_samples = elbo_samples
414415
self.eta = eta
415416
self.adapt_iter = adapt_iter
417+
self.adapt_engaged = adapt_engaged
416418
self.tol_rel_obj = tol_rel_obj
417419
self.eval_elbo = eval_elbo
418420
self.output_samples = output_samples
@@ -453,19 +455,19 @@ def validate(self, chains=None) -> None: # pylint: disable=unused-argument
453455
' found {}'.format(self.elbo_samples)
454456
)
455457
if self.eta is not None:
456-
if self.eta < 1 or not isinstance(self.eta, (Integral, Real)):
458+
if self.eta < 0 or not isinstance(self.eta, (Integral, Real)):
457459
raise ValueError(
458460
'eta must be a non-negative number,'
459461
' found {}'.format(self.eta)
460462
)
461463
if self.adapt_iter is not None:
462-
if self.adapt_iter < 1 or not isinstance(self.eta, Integral):
464+
if self.adapt_iter < 1 or not isinstance(self.adapt_iter, Integral):
463465
raise ValueError(
464466
'adapt_iter must be a positive integer,'
465467
' found {}'.format(self.adapt_iter)
466468
)
467469
if self.tol_rel_obj is not None:
468-
if self.tol_rel_obj < 1 or not isinstance(
470+
if self.tol_rel_obj <= 0 or not isinstance(
469471
self.tol_rel_obj, (Integral, Real)
470472
):
471473
raise ValueError(
@@ -503,9 +505,13 @@ def compose(self, idx: int, cmd: List) -> str:
503505
cmd.append('elbo_samples={}'.format(self.elbo_samples))
504506
if self.eta is not None:
505507
cmd.append('eta={}'.format(self.eta))
506-
if self.adapt_iter is not None:
507-
cmd.append('adapt')
508-
cmd.append('iter={}'.format(self.adapt_iter))
508+
cmd.append('adapt')
509+
if self.adapt_engaged:
510+
cmd.append('engaged=1')
511+
if self.adapt_iter is not None:
512+
cmd.append('iter={}'.format(self.adapt_iter))
513+
else:
514+
cmd.append('engaged=0')
509515
if self.tol_rel_obj is not None:
510516
cmd.append('tol_rel_obj={}'.format(self.tol_rel_obj))
511517
if self.eval_elbo is not None:

cmdstanpy/model.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -900,10 +900,12 @@ def variational(
900900
grad_samples: int = None,
901901
elbo_samples: int = None,
902902
eta: Real = None,
903+
adapt_engaged: bool = True,
903904
adapt_iter: int = None,
904905
tol_rel_obj: Real = None,
905906
eval_elbo: int = None,
906907
output_samples: int = None,
908+
require_converged: bool = True,
907909
) -> CmdStanVB:
908910
"""
909911
Run CmdStan's variational inference algorithm to approximate
@@ -961,6 +963,8 @@ def variational(
961963
962964
:param eta: Stepsize scaling parameter.
963965
966+
:param adapt_engaged: Whether eta adaptation is engaged.
967+
964968
:param adapt_iter: Number of iterations for eta adaptation.
965969
966970
:param tol_rel_obj: Relative tolerance parameter for convergence.
@@ -970,6 +974,9 @@ def variational(
970974
:param output_samples: Number of approximate posterior output draws
971975
to save.
972976
977+
:param require_converged: Whether or not to raise an error if stan
978+
reports that "The algorithm may not have converged".
979+
973980
:return: CmdStanVB object
974981
"""
975982
variational_args = VariationalArgs(
@@ -978,6 +985,7 @@ def variational(
978985
grad_samples=grad_samples,
979986
elbo_samples=elbo_samples,
980987
eta=eta,
988+
adapt_engaged=adapt_engaged,
981989
adapt_iter=adapt_iter,
982990
tol_rel_obj=tol_rel_obj,
983991
eval_elbo=eval_elbo,
@@ -1010,7 +1018,7 @@ def variational(
10101018
errors = re.findall(pat, contents)
10111019
if len(errors) > 0:
10121020
valid = False
1013-
if not valid:
1021+
if require_converged and not valid:
10141022
raise RuntimeError('The algorithm may not have converged.')
10151023
if not runset._check_retcodes():
10161024
msg = 'Error during variational inference.\n{}'.format(

cmdstanpy/utils.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -553,22 +553,17 @@ def scan_variational_csv(path: str) -> Dict:
553553
lineno = scan_column_names(fd, dict, lineno)
554554
line = fd.readline().lstrip(' #\t').rstrip()
555555
lineno += 1
556-
if not line.startswith('Stepsize adaptation complete.'):
557-
raise ValueError(
558-
'line {}: expecting adaptation msg, found:\n\t "{}"'.format(
559-
lineno, line
560-
)
561-
)
562-
line = fd.readline().lstrip(' #\t\n')
563-
lineno += 1
564-
if not line.startswith('eta = 1'):
565-
raise ValueError(
566-
'line {}: expecting eta = 1, found:\n\t "{}"'.format(
567-
lineno, line
556+
if line.startswith('Stepsize adaptation complete.'):
557+
line = fd.readline().lstrip(' #\t\n')
558+
lineno += 1
559+
if not line.startswith('eta'):
560+
raise ValueError(
561+
'line {}: expecting eta, found:\n\t "{}"'.format(
562+
lineno, line
563+
)
568564
)
569-
)
570-
line = fd.readline().lstrip(' #\t\n')
571-
lineno += 1
565+
line = fd.readline().lstrip(' #\t\n')
566+
lineno += 1
572567
xs = line.split(',')
573568
variational_mean = [float(x) for x in xs]
574569
dict['variational_mean'] = variational_mean

test/test_cmdstan_args.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -601,11 +601,26 @@ def test_args_variational(self):
601601
self.assertIn('method=variational', ' '.join(cmd))
602602
self.assertIn('output_samples=1', ' '.join(cmd))
603603

604-
args = VariationalArgs(tol_rel_obj=1)
604+
args = VariationalArgs(tol_rel_obj=0.01)
605605
args.validate(chains=1)
606606
cmd = args.compose(idx=0, cmd=[])
607607
self.assertIn('method=variational', ' '.join(cmd))
608-
self.assertIn('tol_rel_obj=1', ' '.join(cmd))
608+
self.assertIn('tol_rel_obj=0.01', ' '.join(cmd))
609+
610+
args = VariationalArgs(adapt_engaged=True, adapt_iter=100)
611+
args.validate(chains=1)
612+
cmd = args.compose(idx=0, cmd=[])
613+
self.assertIn('adapt engaged=1 iter=100', ' '.join(cmd))
614+
615+
args = VariationalArgs(adapt_engaged=False)
616+
args.validate(chains=1)
617+
cmd = args.compose(idx=0, cmd=[])
618+
self.assertIn('adapt engaged=0', ' '.join(cmd))
619+
620+
args = VariationalArgs(eta=0.1)
621+
args.validate(chains=1)
622+
cmd = args.compose(idx=0, cmd=[])
623+
self.assertIn('eta=0.1', ' '.join(cmd))
609624

610625
def test_args_bad(self):
611626
args = VariationalArgs(algorithm='no_such_algo')

test/test_variational.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,10 @@ def test_variational_eta_fail(self):
147147
):
148148
model.variational(algorithm='meanfield', seed=12345)
149149

150+
model.variational(
151+
algorithm='meanfield', seed=12345, require_converged=False
152+
)
153+
150154

151155
if __name__ == '__main__':
152156
unittest.main()

0 commit comments

Comments
 (0)