@@ -375,6 +375,9 @@ def variadic_sort(input, size, descending=False):
375375 input (Tensor): input of shape :math:`(B, ...)`
376376 size (LongTensor): size of sets of shape :math:`(N,)`
377377 descending (bool, optional): return ascending or descending order
378+
379+ Returns
380+ (Tensor, LongTensor): sorted values and indexes
378381 """
379382 index2sample = _size_to_index (size )
380383 index2sample = index2sample .view ([- 1 ] + [1 ] * (input .ndim - 1 ))
@@ -445,6 +448,21 @@ def variadic_sample(input, size, num_sample):
445448
446449
447450def variadic_meshgrid (input1 , size1 , input2 , size2 ):
451+ """
452+ Compute the Cartesian product for two batches of sets with variadic sizes.
453+
454+ Suppose there are :math:`N` sets in each input,
455+ and the sizes of all sets are summed to :math:`B_1` and :math:`B_2` respectively.
456+
457+ Parameters:
458+ input1 (Tensor): input of shape :math:`(B_1, ...)`
459+ size1 (LongTensor): size of :attr:`input1` of shape :math:`(N,)`
460+ input2 (Tensor): input of shape :math:`(B_2, ...)`
461+ size2 (LongTensor): size of :attr:`input2` of shape :math:`(N,)`
462+
463+ Returns
464+ (Tensor, Tensor): the first and the second elements in the Cartesian product
465+ """
448466 grid_size = size1 * size2
449467 local_index = variadic_arange (grid_size )
450468 local_inner_size = size2 .repeat_interleave (grid_size )
@@ -456,6 +474,19 @@ def variadic_meshgrid(input1, size1, input2, size2):
456474
457475
458476def variadic_to_padded (input , size , value = 0 ):
477+ """
478+ Convert a variadic tensor to a padded tensor.
479+
480+ Suppose there are :math:`N` sets, and the sizes of all sets are summed to :math:`B`.
481+
482+ Parameters:
483+ input (Tensor): input of shape :math:`(B, ...)`
484+ size (LongTensor): size of sets of shape :math:`(N,)`
485+ value (scalar): fill value for padding
486+
487+ Returns:
488+ (Tensor, BoolTensor): padded tensor and mask
489+ """
459490 num_sample = len (size )
460491 max_size = size .max ()
461492 starts = torch .arange (num_sample , device = size .device ) * max_size
@@ -469,6 +500,13 @@ def variadic_to_padded(input, size, value=0):
469500
470501
471502def padded_to_variadic (padded , size ):
503+ """
504+ Convert a padded tensor to a variadic tensor.
505+
506+ Parameters:
507+ padded (Tensor): padded tensor of shape :math:`(N, ...)`
508+ size (LongTensor): size of sets of shape :math:`(N,)`
509+ """
472510 num_sample , max_size = padded .shape [:2 ]
473511 starts = torch .arange (num_sample , device = size .device ) * max_size
474512 ends = starts + size
0 commit comments