
Small bug in classification_inference for csv_data_configuration #224

@HFooladi


There is a small bug in the examples/property_prediction/csv_data_configuration/classification_inference.py script.

On line 37, the output of the predict function is a logit, so in theory it can take any value from -inf to inf. The current code nevertheless thresholds this raw output directly at 0.5:

batch_pred = predict(args, model, bg)
if not args['soft_classification']:
    batch_pred = (batch_pred >= 0.5).float()
predictions.append(batch_pred.detach().cpu())
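To see why thresholding the raw output is wrong, here is a minimal stdlib-only sketch (the helper names sigmoid, buggy_label and fixed_label are hypothetical, not from the repository). Since sigmoid(z) >= 0.5 exactly when z >= 0, comparing a logit against 0.5 amounts to a probability cut-off of sigmoid(0.5) ≈ 0.622, so samples whose logits fall in [0, 0.5) get labelled negative even though their predicted probability is at least 0.5.

```python
import math

def sigmoid(z: float) -> float:
    """Logistic function: maps a logit in (-inf, inf) to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))

def buggy_label(logit: float) -> float:
    # What the current code does: threshold the raw logit at 0.5
    return float(logit >= 0.5)

def fixed_label(logit: float) -> float:
    # Proposed fix: convert to a probability first, then threshold at 0.5
    return float(sigmoid(logit) >= 0.5)

logits = [-2.0, 0.0, 0.3, 0.5, 2.0]
for z in logits:
    print(f"logit={z:+.1f} p={sigmoid(z):.3f} buggy={buggy_label(z)} fixed={fixed_label(z)}")

# Effective probability cut-off implied by the buggy rule
print(f"buggy rule cut-off in probability space: {sigmoid(0.5):.3f}")
```

The logits 0.0 and 0.3 are the disagreeing cases: their probabilities are 0.5 and about 0.574, yet the buggy rule labels them 0.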

So the logit should first be converted to a probability in [0, 1] with the sigmoid function, and only then be used to produce the hard or soft classification label:

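batch_logit = predict(args, model, bg)
# predict returns raw logits; map them to probabilities in [0, 1]
batch_pred = torch.sigmoid(batch_logit)
if not args['soft_classification']:
    # hard labels: threshold the probabilities (not the logits) at 0.5
    batch_pred = (batch_pred >= 0.5).float()
predictions.append(batch_pred.detach().cpu())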
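For reference, a minimal self-contained sketch of the corrected post-processing, assuming predict returns a float tensor of logits shaped (batch_size, n_tasks). The function name postprocess_logits is hypothetical and not part of the example script; it only wraps the proposed lines so they can be run in isolation.

```python
import torch

def postprocess_logits(batch_logit: torch.Tensor, soft_classification: bool) -> torch.Tensor:
    """Turn raw model logits into soft probabilities or hard 0/1 labels."""
    # Map logits in (-inf, inf) to probabilities in (0, 1)
    batch_pred = torch.sigmoid(batch_logit)
    if not soft_classification:
        # Threshold the probabilities, not the logits
        batch_pred = (batch_pred >= 0.5).float()
    # Detach from the autograd graph and move to CPU, as the script does
    return batch_pred.detach().cpu()

logits = torch.tensor([[-2.0, 0.0], [0.3, 2.0]])
print(postprocess_logits(logits, soft_classification=True))
print(postprocess_logits(logits, soft_classification=False))
```

For hard labels only, (batch_logit >= 0).float() is mathematically equivalent and skips the sigmoid, though it can differ from the sigmoid-then-threshold version by floating-point rounding for logits extremely close to zero.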
