Let's implement a correlation loss / metric (with PyTorch)

dev_halo · April 6, 2022
Python

Loss

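import torch

def correlation_loss(y_pred, y_true):
    # Center both tensors; subtraction already returns new tensors, so clone() is not needed.
    vx = y_pred - torch.mean(y_pred)
    vy = y_true - torch.mean(y_true)
    cov = torch.sum(vx * vy)
    # Pearson correlation; 1e-12 keeps the denominator away from zero.
    corr = cov / (torch.sqrt(torch.sum(vx ** 2)) * torch.sqrt(torch.sum(vy ** 2)) + 1e-12)
    # Clamp away floating-point overshoot outside [-1, 1].
    corr = torch.clamp(corr, -1.0, 1.0)
    # 0 for a perfect linear relation (either sign), 1 when uncorrelated.
    return 1 - corr ** 2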
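
A quick way to get a feel for this loss is to feed it a few hand-made inputs. The sketch below repeats the loss in compact form so it runs on its own; the tensors are made-up values, not from any real dataset. Because the loss uses the squared correlation, it should come out close to 0 for a perfect linear relation with either a positive or a negative slope, close to 1 for uncorrelated inputs, and it should stay differentiable with respect to the predictions.

```python
import torch

def correlation_loss(y_pred, y_true):
    # Same computation as the loss above, written compactly.
    vx = y_pred - y_pred.mean()
    vy = y_true - y_true.mean()
    corr = (vx * vy).sum() / (vx.pow(2).sum().sqrt() * vy.pow(2).sum().sqrt() + 1e-12)
    corr = torch.clamp(corr, -1.0, 1.0)
    return 1 - corr ** 2

y_true = torch.tensor([1.0, 2.0, 3.0, 4.0])        # made-up target values

loss_pos = correlation_loss(2 * y_true + 1, y_true)  # perfect positive relation
loss_neg = correlation_loss(-y_true, y_true)          # perfect negative relation
loss_unc = correlation_loss(torch.tensor([1.0, -1.0, -1.0, 1.0]), y_true)  # zero covariance

print(loss_pos.item(), loss_neg.item(), loss_unc.item())

# Gradients flow back to the predictions, so the loss can be used for training.
y_pred = torch.tensor([0.5, 2.0, 1.0, 3.5], requires_grad=True)
correlation_loss(y_pred, y_true).backward()
print(y_pred.grad)
```

Note that a perfectly anti-correlated prediction also gets a loss near 0. If only positive correlation should be rewarded, returning something like `1 - corr` would be one alternative, but that is a different loss from the one in this post.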

Metric

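def correlation_metric(y_pred, y_true):
    # Accept lists, NumPy arrays, or tensors and compute in float32.
    x = torch.as_tensor(y_pred, dtype=torch.float32)
    y = torch.as_tensor(y_true, dtype=torch.float32)
    vx = x - torch.mean(x)
    vy = y - torch.mean(y)
    cov = torch.sum(vx * vy)
    # Pearson correlation coefficient; 1e-12 keeps the denominator away from zero.
    corr = cov / (torch.sqrt(torch.sum(vx ** 2)) * torch.sqrt(torch.sum(vy ** 2)) + 1e-12)
    return corr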
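
As a sanity check, the metric can be compared with `torch.corrcoef`, which is available in recent PyTorch releases. This is only a sketch: the metric is repeated in compact form so the snippet runs on its own, and the input values are made up for illustration.

```python
import torch

def correlation_metric(y_pred, y_true):
    # Same computation as the metric above, written compactly.
    x = torch.as_tensor(y_pred, dtype=torch.float32)
    y = torch.as_tensor(y_true, dtype=torch.float32)
    vx = x - x.mean()
    vy = y - y.mean()
    return (vx * vy).sum() / (vx.pow(2).sum().sqrt() * vy.pow(2).sum().sqrt() + 1e-12)

y_pred = [0.1, 0.4, 0.35, 0.8]   # made-up predictions
y_true = [0.0, 0.5, 0.3, 1.0]    # made-up targets

ours = correlation_metric(y_pred, y_true)
# torch.corrcoef treats each row as one variable; entry [0, 1] is corr(y_pred, y_true).
reference = torch.corrcoef(torch.tensor([y_pred, y_true]))[0, 1]

print(ours.item(), reference.item())  # the two values should agree closely
```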