# tuple_select.py
import halide as hl
import numpy as np


def _assert_tuple_func_matches(f, expected):
    """Realize the two-output Func ``f`` and compare every element.

    Args:
        f: Func whose definition is a two-element tuple; it is realized
            over a 200x200 domain.
        expected: callable taking the integer coordinates ``(xx, yy)`` and
            returning the pair of values expected in the two outputs.

    Raises:
        AssertionError: if any element of either output differs from the
            expected pair.
    """
    a, b = f.realize(200, 200)
    # The buffers are indexed as [x, y], so the x loop is bounded by width()
    # and the y loop by height(). The realization is square, so this visits
    # exactly the same elements as before, but stays correct if the extents
    # are ever made unequal.
    for xx in range(a.width()):
        for yy in range(a.height()):
            correct_a, correct_b = expected(xx, yy)
            assert a[xx, yy] == correct_a
            assert b[xx, yy] == correct_b


def test_tuple_select():
    """Exercise hl.tuple_select in its ternary and multiway forms.

    Covers both a single Expr condition (one condition selects the whole
    tuple) and a Tuple condition (each element is selected by its own
    condition), then checks that two invalid calls raise RuntimeError with
    the expected messages.

    Raises:
        AssertionError: if a realized value is wrong, or if an invalid call
            does not raise the expected error.
    """
    x = hl.Var('x')
    y = hl.Var('y')

    # ternary tuple_select with Expr condition
    f = hl.Func('f')
    f[x, y] = hl.tuple_select(x + y < 30, (x, y), (x-1, y-2))

    def expected_ternary_expr(xx, yy):
        return (xx, yy) if xx + yy < 30 else (xx - 1, yy - 2)

    _assert_tuple_func_matches(f, expected_ternary_expr)

    # ternary tuple_select with Tuple condition
    f = hl.Func('f')
    f[x, y] = hl.tuple_select((x < 30, y < 30), (x, y), (x-1, y-2))

    def expected_ternary_tuple(xx, yy):
        # Each output element is governed by its own condition.
        return (xx if xx < 30 else xx - 1,
                yy if yy < 30 else yy - 2)

    _assert_tuple_func_matches(f, expected_ternary_tuple)

    # multiway tuple_select with Expr condition
    f = hl.Func('f')
    f[x, y] = hl.tuple_select(x + y < 30,  (x, y),
                              x + y < 100, (x-1, y-2),
                                           (x-100, y-200))

    def expected_multiway_expr(xx, yy):
        if xx + yy < 30:
            return xx, yy
        if xx + yy < 100:
            return xx - 1, yy - 2
        return xx - 100, yy - 200

    _assert_tuple_func_matches(f, expected_multiway_expr)

    # multiway tuple_select with Tuple condition
    f = hl.Func('f')
    f[x, y] = hl.tuple_select((x < 30, y < 30),   (x, y),
                              (x < 100, y < 100), (x-1, y-2),
                                                  (x-100, y-200))

    def expected_multiway_tuple(xx, yy):
        # Each output element walks its own chain of conditions.
        return (xx if xx < 30 else xx - 1 if xx < 100 else xx - 100,
                yy if yy < 30 else yy - 2 if yy < 100 else yy - 200)

    _assert_tuple_func_matches(f, expected_multiway_tuple)

    # Failure case: mixing Expr and Tuple in multiway
    try:
        f = hl.Func('f')
        f[x, y] = hl.tuple_select((x < 30, y < 30), (x, y),
                                   x + y < 100,     (x-1, y-2),
                                                    (x-100, y-200))
    except RuntimeError as e:
        assert 'tuple_select() may not mix Expr and Tuple for the condition elements.' in str(e)
    else:
        # Raise explicitly rather than `assert False`, which is stripped
        # when Python runs with -O.
        raise AssertionError('Did not see expected exception!')

    # Failure case: Tuples of mixed sizes
    try:
        f = hl.Func('f')
        f[x, y] = hl.tuple_select((x < 30, y < 30), (x, y, 0), (1, 2, 3, 4))
    except RuntimeError as e:
        assert 'tuple_select() requires all Tuples to have identical sizes' in str(e)
    else:
        raise AssertionError('Did not see expected exception!')


# Allow running this test directly as a script, without a test runner.
if __name__ == "__main__":
    test_tuple_select()
# End of tuple_select.py