@@ -504,3 +504,52 @@ def compute_loss(self, predictions, y_true):
504504 # Compute the negative log-likelihood
505505 nll = - cat_dist .log_prob (y_true ).mean ()
506506 return nll
507+
508+
509+ class Quantile (BaseDistribution ):
510+ """
511+ Quantile Regression Loss class.
512+
513+ This class computes the quantile loss (also known as pinball loss) for a set of quantiles.
514+ It is used to handle quantile regression tasks where we aim to predict a given quantile of the target distribution.
515+
516+ Parameters
517+ ----------
518+ name : str, optional
519+ The name of the distribution, by default "Quantile".
520+ quantiles : list of float, optional
521+ A list of quantiles to be used for computing the loss, by default [0.25, 0.5, 0.75].
522+
523+ Attributes
524+ ----------
525+ quantiles : list of float
526+ List of quantiles for which the pinball loss is computed.
527+
528+ Methods
529+ -------
530+ compute_loss(predictions, y_true)
531+ Computes the quantile regression loss between the predictions and true values.
532+ """
533+
534+ def __init__ (self , name = "Quantile" , quantiles = [0.25 , 0.5 , 0.75 ]):
535+ param_names = [
536+ f"q_{ q } " for q in quantiles
537+ ] # Use string representations of quantiles
538+ super ().__init__ (name , param_names )
539+ self .quantiles = quantiles
540+
541+ def compute_loss (self , predictions , y_true ):
542+
543+ assert not y_true .requires_grad # Ensure y_true does not require gradients
544+ assert predictions .size (0 ) == y_true .size (0 ) # Ensure batch size matches
545+
546+ losses = []
547+ for i , q in enumerate (self .quantiles ):
548+ errors = y_true - predictions [:, i ] # Calculate errors for each quantile
549+ # Compute the pinball loss
550+ quantile_loss = torch .max ((q - 1 ) * errors , q * errors )
551+ losses .append (quantile_loss )
552+
553+ # Sum losses across quantiles and compute mean
554+ loss = torch .mean (torch .stack (losses , dim = 1 ).sum (dim = 1 ))
555+ return loss
0 commit comments