https://github.com/JuliaLang/julia
Raw File
Tip revision: b372a68743c0139797316fd8b4a95d497ba6d8f0 authored by Keno Fischer on 26 October 2013, 02:06:56 UTC
Tag v0.2.0-rc2
Tip revision: b372a68
reduce.jl
## reductions ##

function reduce(op::Function, itr) # this is a left fold
    # Left fold of `op` over `itr`: op(...op(op(x1,x2),x3)...,xn).
    # `+` and `*` delegate to the dedicated sum/prod implementations.
    if is(op,+)
        return sum(itr)
    elseif is(op,*)
        return prod(itr)
    end
    isor  = is(op,|)
    isand = is(op,&)
    s = start(itr)
    if done(itr, s)
        # `|` and `&` keep the empty-collection results that any/all give.
        isor  && return false
        isand && return true
        return op()  # empty collection
    end
    (v, s) = next(itr, s)
    while !done(itr, s)
        # For Bool values, `|` and `&` stop as soon as the result is decided
        # (like any/all). Non-Bool values, e.g. integers, are folded bitwise
        # with `op` instead of being forced through a boolean test.
        if isa(v, Bool) && ((isor && v) || (isand && !v))
            return v
        end
        (x, s) = next(itr, s)
        v = op(v,x)
    end
    return v
end

function maximum(itr)
    # Largest element of `itr`, combined pairwise with scalarmax from the left.
    # Raises an error for an empty iterable.
    st = start(itr)
    done(itr, st) && error("maximum: argument is empty")
    best, st = next(itr, st)
    while !done(itr, st)
        cand, st = next(itr, st)
        best = scalarmax(best, cand)
    end
    return best
end

function minimum(itr)
    # Smallest element of `itr`, combined pairwise with scalarmin from the left.
    # Raises an error for an empty iterable.
    st = start(itr)
    done(itr, st) && error("minimum: argument is empty")
    best, st = next(itr, st)
    while !done(itr, st)
        cand, st = next(itr, st)
        best = scalarmin(best, cand)
    end
    return best
end

function sum(itr)
    # Add up the elements of `itr` left to right; an empty iterable yields `+()`.
    state = start(itr)
    done(itr, state) && return +()
    total, state = next(itr, state)
    while !done(itr, state)
        item, state = next(itr, state)
        total = total + item
    end
    return total
end

function prod(itr)
    # Multiply the elements of `itr` left to right; an empty iterable yields `*()`.
    state = start(itr)
    done(itr, state) && return *()
    product, state = next(itr, state)
    while !done(itr, state)
        item, state = next(itr, state)
        product = product * item
    end
    return product
end

function reduce(op::Function, v0, itr)
    # Left fold of `op` over `itr`, seeded with `v0`; returns `v0` if `itr` is empty.
    acc = v0
    if is(op,+)
        # `+` and `*` are written out directly rather than called through `op`,
        # presumably to avoid a dynamic call in the hot loop.
        for x in itr
            acc = acc + x
        end
    elseif is(op,*)
        for x in itr
            acc = acc * x
        end
    else
        for x in itr
            acc = op(acc, x)
        end
    end
    return acc
end

##
# generic map on any iterator
function map(f::Callable, iters...)
    # Apply `f` across `iters` in lockstep, stopping as soon as any one of them
    # is exhausted. Results are collected into a cell array ({}).
    len = length(iters)
    # With no iterators nothing ever signals exhaustion, so the loop below
    # would push f() forever; reject that call up front.
    len == 0 && error("map: at least one iterator is required")
    result = {}
    states = [start(iters[idx]) for idx in 1:len]
    nxtvals = cell(len)  # reused buffer holding the current element of each iterator
    cont = true
    for idx in 1:len
        done(iters[idx], states[idx]) && (cont = false; break)
    end
    while cont
        for idx in 1:len
            nxtvals[idx],states[idx] = next(iters[idx], states[idx])
        end
        push!(result, f(nxtvals...))
        for idx in 1:len
            done(iters[idx], states[idx]) && (cont = false; break)
        end
    end
    return result
end

function mapreduce(f::Callable, op::Function, itr)
    # Left fold of `op` over f(x) for each x in `itr`; `op()` if `itr` is empty.
    st = start(itr)
    done(itr, st) && return op()  # empty collection
    x1, st = next(itr, st)
    acc = f(x1)
    while !done(itr, st)
        item, st = next(itr, st)
        acc = op(acc, f(item))
    end
    return acc
end

function mapreduce(f::Callable, op::Function, v0, itr)
    # Left fold of `op` over f(x) for each x in `itr`, seeded with `v0`.
    acc = v0
    st = start(itr)
    while !done(itr, st)
        (x, st) = next(itr, st)
        acc = op(acc, f(x))
    end
    return acc
end

# mapreduce for random-access arrays, using pairwise recursive reduction
# for improved accuracy (see sum_pairwise)
function mr_pairwise(f::Callable, op::Function, A::AbstractArray, i1,n)
    # Reduce f over A[i1:i1+n-1]. Ranges of 128 or more elements are split in
    # half and the halves are combined with `op`; shorter ranges are folded
    # left directly. Assumes n >= 1 (callers check for the empty case).
    if n >= 128
        half = div(n,2)
        left  = mr_pairwise(f,op,A, i1,half)
        right = mr_pairwise(f,op,A, i1+half,n-half)
        return op(left, right)
    end
    @inbounds acc = f(A[i1])
    for k = i1+1:i1+n-1
        @inbounds acc = op(acc, f(A[k]))
    end
    return acc
end
function mapreduce(f::Callable, op::Function, A::AbstractArray)
    # Arrays use the pairwise scheme; an empty array yields `op()`.
    n = length(A)
    if n == 0
        return op()
    end
    return mr_pairwise(f,op,A, 1,n)
end
function mapreduce(f::Callable, op::Function, v0, A::AbstractArray)
    # Pairwise-reduce the array, then combine with `v0` on the left;
    # an empty array returns `v0` unchanged.
    len = length(A)
    len > 0 || return v0
    return op(v0, mr_pairwise(f,op,A, 1,len))
end

function any(itr)
    # True at the first element that is `true`; false for an empty iterable.
    st = start(itr)
    while !done(itr, st)
        (x, st) = next(itr, st)
        if x
            return true
        end
    end
    return false
end

function all(itr)
    # False at the first element that is `false`; true for an empty iterable.
    st = start(itr)
    while !done(itr, st)
        (x, st) = next(itr, st)
        if !x
            return false
        end
    end
    return true
end

# Reductions of f(x) over `itr`, each written as a mapreduce with the matching
# binary operator.
function maximum(f::Function, itr)
    return mapreduce(f, scalarmax, itr)
end
function minimum(f::Function, itr)
    return mapreduce(f, scalarmin, itr)
end
function sum(f::Function, itr)
    return mapreduce(f, +, itr)
end
function prod(f::Function, itr)
    return mapreduce(f, *, itr)
end

function count(pred::Function, itr)
    # Number of elements of `itr` for which `pred` returns true.
    n = 0
    for item in itr
        pred(item) && (n += 1)
    end
    return n
end

function any(pred::Function, itr)
    # True as soon as `pred` holds for some element; false if it never does.
    st = start(itr)
    while !done(itr, st)
        (x, st) = next(itr, st)
        if pred(x)
            return true
        end
    end
    return false
end

function all(pred::Function, itr)
    # False as soon as `pred` fails for some element; true if it never does.
    st = start(itr)
    while !done(itr, st)
        (x, st) = next(itr, st)
        if !pred(x)
            return false
        end
    end
    return true
end

function in(x, itr)
    # Membership test using `isequal`; stops at the first matching element.
    for candidate in itr
        isequal(candidate, x) && return true
    end
    return false
end

function contains(itr, x)
    # Deprecated argument order: emit a deprecation warning, then forward to in(x, itr).
    depwarn("contains(collection, item) is deprecated, use in(item, collection) instead", :contains)
    return in(x, itr)
end

function contains(eq::Function, itr, x)
    # True as soon as eq(y, x) holds for some element y of `itr`.
    st = start(itr)
    while !done(itr, st)
        (y, st) = next(itr, st)
        eq(y, x) && return true
    end
    return false
end
back to top