@@ -132,7 +132,7 @@ def validate(self, chains: int) -> None:
132132 )
133133 if self .step_size is not None :
134134 if isinstance (self .step_size , Real ):
135- if self .step_size < 0 :
135+ if self .step_size <= 0 :
136136 raise ValueError (
137137 'step_size must be > 0, found {}' .format (self .step_size )
138138 )
@@ -336,7 +336,7 @@ def validate(self, chains=None) -> None: # pylint: disable=unused-argument
336336 'init_alpha must not be set when algorithm is Newton'
337337 )
338338 if isinstance (self .init_alpha , Real ):
339- if self .init_alpha < 0 :
339+ if self .init_alpha <= 0 :
340340 raise ValueError ('init_alpha must be greater than 0' )
341341 else :
342342 raise ValueError ('init_alpha must be type of float' )
@@ -403,6 +403,7 @@ def __init__(
403403 elbo_samples : int = None ,
404404 eta : Real = None ,
405405 adapt_iter : int = None ,
406+ adapt_engaged : bool = True ,
406407 tol_rel_obj : Real = None ,
407408 eval_elbo : int = None ,
408409 output_samples : int = None ,
@@ -413,6 +414,7 @@ def __init__(
413414 self .elbo_samples = elbo_samples
414415 self .eta = eta
415416 self .adapt_iter = adapt_iter
417+ self .adapt_engaged = adapt_engaged
416418 self .tol_rel_obj = tol_rel_obj
417419 self .eval_elbo = eval_elbo
418420 self .output_samples = output_samples
@@ -453,19 +455,19 @@ def validate(self, chains=None) -> None: # pylint: disable=unused-argument
453455 ' found {}' .format (self .elbo_samples )
454456 )
455457 if self .eta is not None :
456- if self .eta < 1 or not isinstance (self .eta , (Integral , Real )):
458+ if self .eta < 0 or not isinstance (self .eta , (Integral , Real )):
457459 raise ValueError (
458460 'eta must be a non-negative number,'
459461 ' found {}' .format (self .eta )
460462 )
461463 if self .adapt_iter is not None :
462- if self .adapt_iter < 1 or not isinstance (self .eta , Integral ):
464+ if self .adapt_iter < 1 or not isinstance (self .adapt_iter , Integral ):
463465 raise ValueError (
464466 'adapt_iter must be a positive integer,'
465467 ' found {}' .format (self .adapt_iter )
466468 )
467469 if self .tol_rel_obj is not None :
468- if self .tol_rel_obj < 1 or not isinstance (
470+ if self .tol_rel_obj <= 0 or not isinstance (
469471 self .tol_rel_obj , (Integral , Real )
470472 ):
471473 raise ValueError (
@@ -503,9 +505,13 @@ def compose(self, idx: int, cmd: List) -> str:
503505 cmd .append ('elbo_samples={}' .format (self .elbo_samples ))
504506 if self .eta is not None :
505507 cmd .append ('eta={}' .format (self .eta ))
506- if self .adapt_iter is not None :
507- cmd .append ('adapt' )
508- cmd .append ('iter={}' .format (self .adapt_iter ))
508+ cmd .append ('adapt' )
509+ if self .adapt_engaged :
510+ cmd .append ('engaged=1' )
511+ if self .adapt_iter is not None :
512+ cmd .append ('iter={}' .format (self .adapt_iter ))
513+ else :
514+ cmd .append ('engaged=0' )
509515 if self .tol_rel_obj is not None :
510516 cmd .append ('tol_rel_obj={}' .format (self .tol_rel_obj ))
511517 if self .eval_elbo is not None :
0 commit comments