https://github.com/alvinwan/neural-backed-decision-trees
Raw File
Tip revision: a7a2ee6f735bbc1b3d8c7c4f9ecdd02c6a75fc1e authored by Alvin Wan on 03 June 2021, 04:38:35 UTC
Merge pull request #20 from alvinwan/dependabot/pip/examples/app/flask-cors-3.0.9
Tip revision: a7a2ee6
metrics.py
import torch


# Names of the metric factory functions defined in this module. The same
# tuple object doubles as ``__all__``, so a star-import exposes only these.
names = ("top1", "top2", "top5", "top10")
__all__ = names


class TopK:
    """Running top-k accuracy accumulator.

    Call :meth:`forward` once per batch. :meth:`report` returns the fraction of
    all samples seen since the last :meth:`clear` whose target index is among
    the ``k`` highest-scoring entries of the corresponding ``outputs`` row.
    """

    def __init__(self, k=1):
        """Create a meter for top-``k`` accuracy with zeroed counters."""
        self.k = k
        self.clear()

    def clear(self):
        """Reset the accumulated counts to zero."""
        self.correct = 0  # samples whose target was among the top-k predictions
        self.total = 0  # samples seen since the last clear()

    def forward(self, outputs, targets):
        """Accumulate top-k hits for one batch.

        Args:
            outputs: score tensor; top-k is taken over its last dimension.
                Presumably shaped (batch, num_classes) -- confirm with callers.
            targets: 1-D integer tensor with one class index per row of
                ``outputs``.
        """
        # torch.topk raises if k exceeds the row width. With k >= num_classes
        # every class is "in the top k", so clamping yields the correct result
        # instead of an error.
        k = min(self.k, outputs.size(-1))
        _, preds = torch.topk(outputs, k)
        # Compare each row's k predictions against its own target in one
        # vectorized op. Unlike summing a Python list, this also works for an
        # empty batch (sum([]) would be a plain int without .item()).
        expected = targets.reshape(-1, 1).to(preds.device)
        hits = preds.eq(expected).any(dim=1)
        self.correct += hits.sum().item()
        self.total += targets.size(0)

    def report(self):
        """Return the accuracy so far, or 0 if no samples have been seen."""
        return self.correct / (self.total or 1)

    def __repr__(self):
        return f"Top{self.k}: {self.report()}"

    def __str__(self):
        return repr(self)


def top1():
    """Return a fresh top-1 accuracy meter."""
    return TopK(1)


def top2():
    """Return a fresh top-2 accuracy meter."""
    return TopK(2)


def top5():
    """Return a fresh top-5 accuracy meter."""
    return TopK(5)


def top10():
    """Return a fresh top-10 accuracy meter."""
    return TopK(10)
back to top