Revision 869027cc7911b638b8292c67182fb6f2b8f0f3c9 authored by dependabot[bot] on 21 June 2022, 22:33:57 UTC, committed by GitHub on 21 June 2022, 22:33:57 UTC
Bumps [numpy](https://github.com/numpy/numpy) from 1.16.4 to 1.22.0.
- [Release notes](https://github.com/numpy/numpy/releases)
- [Changelog](https://github.com/numpy/numpy/blob/main/doc/HOWTO_RELEASE.rst)
- [Commits](https://github.com/numpy/numpy/compare/v1.16.4...v1.22.0)

---
updated-dependencies:
- dependency-name: numpy
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
1 parent 78da4ea
Raw File
loss.py
"""
Losses
"""
# pylint: disable=C0301,C0103,R0902,R0915,W0221,W0622


##
# LIBRARIES
import torch

##
def l1_loss(input, target, size_average=True):
    """ L1 (absolute-error) loss.

    Args:
        input (FloatTensor): Input (e.g. reconstructed) tensor
        target (FloatTensor): Target tensor to compare against
        size_average (bool): If True (default), return the mean absolute
            error over all elements as a scalar tensor. If False, return
            the element-wise absolute differences without reduction
            (mirrors the ``size_average`` flag of ``l2_loss``).

    Returns:
        [FloatTensor]: Mean absolute error (scalar) when ``size_average`` is
        True, otherwise ``|input - target|`` element-wise.
    """
    abs_diff = torch.abs(input - target)
    if size_average:
        # torch.mean with no dim reduces over every element -> 0-d tensor
        return torch.mean(abs_diff)
    return abs_diff

##
def l2_loss(input, target, size_average=True):
    """ Squared-error (MSE-style) loss.

    Note: no square root is taken. The result is the squared error, not a
    Euclidean (L2) distance.

    Args:
        input (FloatTensor): Input (e.g. reconstructed) tensor
        target (FloatTensor): Target tensor to compare against
        size_average (bool): If True (default), return the mean of the
            squared differences over all elements as a scalar tensor. If
            False, return the element-wise squared differences without
            reduction.

    Returns:
        [FloatTensor]: Mean squared error (scalar) when ``size_average`` is
        True, otherwise ``(input - target) ** 2`` element-wise.
    """
    sq_diff = torch.pow((input - target), 2)
    if size_average:
        # torch.mean with no dim reduces over every element -> 0-d tensor
        return torch.mean(sq_diff)
    return sq_diff
back to top