module Cartesian

export @ngenerate, @nsplat, @nloops, @nref, @ncall, @nexprs, @nextract, @nall, @ntuple, @nif, ngenerate

# @ngenerate/@nsplat emit explicit methods for dimensionalities 1:CARTESIAN_DIMS by default.
const CARTESIAN_DIMS = 4

### @ngenerate, for auto-generation of separate versions of functions for different dimensionalities
# Examples (deliberately trivial):
#    @ngenerate N returntype myndims{T,N}(A::Array{T,N}) = N
# or alternatively
#    function gen_body(N::Int)
#        quote
#            return $N
#        end
#    end
#    eval(ngenerate(:N, returntypeexpr, :(myndims{T,N}(A::Array{T,N})), gen_body))
# The latter lets you use a single gen_body function both for ngenerate and for
# functions that maintain their own method cache (e.g., reduction or broadcasting).
#
# Special syntax for function prototypes:
#   @ngenerate N returntype function myfunction(A::AbstractArray, I::NTuple{N, Int}...)
# for N = 3 translates to
#   function myfunction(A::AbstractArray, I_1::Int, I_2::Int, I_3::Int)
# and for the generic (cached) case to
#   function myfunction(A::AbstractArray, I::Int...)
# which sets N = length(I), then generates (once per N) and calls the N-specific version.
# N should _not_ be listed as a parameter of the function unless earlier arguments use it
# that way.
# To avoid ambiguity, it would be preferable to have some specific syntax for this, such as
#   myfunction(A::AbstractArray, I::Int...N)
# where N can be an integer or symbol. Currently T...N generates a parser error.
macro ngenerate(itersym, returntypeexpr, funcexpr)
    isfuncexpr(funcexpr) || error("Requires a function expression")
    # The body generator substitutes the concrete N for `itersym` in a fresh copy of the body.
    esc(ngenerate(itersym, returntypeexpr, funcexpr.args[1], N->sreplace!(copy(funcexpr.args[2]), itersym, N)))
end

# @nsplat takes an expression like
#    @nsplat N 2:3 myfunction(A, I::NTuple{N,Real}...) = getindex(A, I...)
# and generates
#    myfunction(A, I_1::Real, I_2::Real) = getindex(A, I_1, I_2)
#    myfunction(A, I_1::Real, I_2::Real, I_3::Real) = getindex(A, I_1, I_2, I_3)
#    myfunction(A, I::Real...) = getindex(A, I...)
# An @nsplat function _cannot_ have any other Cartesian macros in it.
# If you omit the range, it uses 1:CARTESIAN_DIMS.
macro nsplat(itersym, args...)
    local rng
    if length(args) == 1
        rng = 1:CARTESIAN_DIMS
        funcexpr = args[1]
    elseif length(args) == 2
        rangeexpr = args[1]
        funcexpr = args[2]
        if !isa(rangeexpr, Expr) || rangeexpr.head != :(:) || length(rangeexpr.args) != 2
            # The range follows `itersym`, so it is the macro's second argument.
            error("Second argument must be a from:to expression")
        end
        rng = rangeexpr.args[1]:rangeexpr.args[2]
    else
        error("Wrong number of arguments")
    end
    # The function expression is always the last macro argument (2nd or 3rd).
    isfuncexpr(funcexpr) || error("Last argument must be a function expression")
    prototype = funcexpr.args[1]
    body = funcexpr.args[2]
    varname, T = get_splatinfo(prototype, itersym)
    isempty(varname) && error("Last argument must be a splat")
    # One explicit, desplatted method per N in the range...
    explicit = [Expr(:function, resolvesplat!(copy(prototype), varname, T, N), resolvesplats!(copy(body), varname, N)) for N in rng]
    # ...plus a generic fallback taking `varname::T...` with the original body.
    protosplat = resolvesplat!(copy(prototype), varname, T, 0)
    protosplat.args[end] = Expr(:..., protosplat.args[end])
    splat = Expr(:function, protosplat, body)
    esc(Expr(:block, explicit..., splat))
end

# Build the function definition specialized to dimensionality N.
# In the signature it desplats the trailing vararg, substitutes N for `itersym`, and
# drops integer type parameters. In the body from bodyfunc(N) it desplats calls.
generate1(itersym, prototype, bodyfunc, N::Int, varname, T) =
    Expr(:function,
         spliceint!(sreplace!(resolvesplat!(copy(prototype), varname, T, N), itersym, N)),
         resolvesplats!(bodyfunc(N), varname, N))

# Return a block defining explicit methods of `prototype` for each N in `dims`.
# When `makecached` is true, the block also contains a generic method. That method
# determines N at run time (from the vararg length, or checks it against the type
# parameter). On first use for a given N it generates and evals an N-specific function
# `_F_` via generate1, stores it in a Dict{Int,Function} named `<fname>_cache`, and
# calls it, asserting the result is of type `returntypeexpr`.
#   itersym        -- symbol standing for the dimensionality (e.g. :N)
#   returntypeexpr -- return-type expression asserted by the cached generic method
#   prototype      -- the call expression of the function signature
#   bodyfunc       -- N::Int -> Expr producing the body for dimensionality N
function ngenerate(itersym, returntypeexpr, prototype, bodyfunc, dims=1:CARTESIAN_DIMS, makecached::Bool = true)
    varname, T = get_splatinfo(prototype, itersym)
    # Generate versions for specific dimensions
    fdim = [generate1(itersym, prototype, bodyfunc, N, varname, T) for N in dims]
    if !makecached
        return Expr(:block, fdim...)
    end
    # Generate the generic cache-based version
    if isempty(varname)
        setitersym = :()
    else
        s = symbol(varname)
        # If N is already bound as a type parameter, only check it; otherwise define it.
        setitersym = hasparameter(prototype, itersym) ? (:(@assert $itersym == length($s))) : (:($itersym = length($s)))
    end
    fsym = funcsym(prototype)
    dictname = symbol(string(fsym)*"_cache")
    fargs = funcargs(prototype)
    if !isempty(varname)
        # Forward the vararg as `I...` to the specialized function.
        fargs[end] = Expr(:..., fargs[end].args[1])
    end
    # Must be taken before resolvesplat! below mutates `prototype` in place.
    flocal = funcrename(copy(prototype), :_F_)
    F = Expr(:function, resolvesplat!(prototype, varname, T), quote
        $setitersym
        if !haskey($dictname, $itersym)
            gen1 = Base.Cartesian.generate1($(symbol(itersym)), $(Expr(:quote, flocal)), $bodyfunc, $itersym, $varname, $T)
            $(dictname)[$itersym] = eval(quote
                local _F_
                $gen1
                _F_
            end)
        end
        ($(dictname)[$itersym]($(fargs...)))::$returntypeexpr
    end)
    Expr(:block, fdim..., quote
        let $dictname = Dict{Int,Function}()
            $F
        end
    end)
end

# True for `function f(...) ... end` and for short-form `f(...) = ...` definitions.
isfuncexpr(ex::Expr) = ex.head == :function || (ex.head == :(=) && isa(ex.args[1], Expr) && ex.args[1].head == :call)
isfuncexpr(arg) = false

# Recursively replace every occurrence of the symbol `sym` by `val`, mutating Exprs in place.
sreplace!(arg, sym, val) = arg
function sreplace!(ex::Expr, sym, val)
    for i = 1:length(ex.args)
        ex.args[i] = sreplace!(ex.args[i], sym, val)
    end
    ex
end
sreplace!(s::Symbol, sym, val) = s == sym ? val : s

# If using the syntax that will need "desplatting",
#    myfunction(A::AbstractArray, I::NTuple{N, Int}...)
# return the variable name (as a string) and the element type (here "I" and :Int).
# Otherwise return ("", Nothing).
function get_splatinfo(ex::Expr, itersym::Symbol)
    if ex.head == :call
        a = ex.args[end]
        if isa(a, Expr) && a.head == :... && length(a.args) == 1
            b = a.args[1]
            if isa(b, Expr) && b.head == :(::)
                varname = string(b.args[1])
                c = b.args[2]
                if isa(c, Expr) && c.head == :curly && c.args[1] == :NTuple && c.args[2] == itersym
                    T = c.args[3]
                    return varname, T
                end
            end
        end
    end
    "", Nothing
end

# Replace the splatted last argument with N desplatted ones: varname_1::T, ..., varname_N::T.
# For N == 0 the last argument becomes `varname::T` (no suffix).
# Does nothing when varname is empty.
function resolvesplat!(prototype, varname, T::Union(Type,Symbol,Expr), N::Int)
    if !isempty(varname)
        prototype.args[end] = N > 0 ? Expr(:(::), symbol(string(varname, "_1")), T) :
                                      Expr(:(::), symbol(varname), T)
        for i = 2:N
            push!(prototype.args, Expr(:(::), symbol(string(varname, "_", i)), T))
        end
    end
    prototype
end

# Return the generic splatting form, e.g.,
#    myfunction(A::AbstractArray, I::Int...)
function resolvesplat!(prototype, varname, T::Union(Type,Symbol,Expr))
    if !isempty(varname)
        svarname = symbol(varname)
        prototype.args[end] = Expr(:..., :($svarname::$T))
    end
    prototype
end

# Desplatting function calls: replace func(a, b, I...) with func(a, b, I_1, I_2, I_3).
# Only a splat of `varname` in the last argument position of a call is rewritten.
resolvesplats!(arg, varname, N) = arg
function resolvesplats!(ex::Expr, varname, N::Int)
    if ex.head == :call
        for i = 2:length(ex.args)-1
            resolvesplats!(ex.args[i], varname, N)
        end
        a = ex.args[end]
        if isa(a, Expr) && a.head == :... && a.args[1] == symbol(varname)
            ex.args[end] = symbol(string(varname, "_1"))
            for i = 2:N
                push!(ex.args, symbol(string(varname, "_", i)))
            end
        else
            resolvesplats!(a, varname, N)
        end
    else
        for i = 1:length(ex.args)
            resolvesplats!(ex.args[i], varname, N)
        end
    end
    ex
end

# Remove any function type parameters that are integers, e.g. f{T,3}(...) -> f{T}(...).
function spliceint!(ex::Expr)
    if ex.head == :escape
        return esc(spliceint!(ex.args[1]))
    end
    ex.head == :call || error(string(ex, " must be a call"))
    if isa(ex.args[1], Expr) && ex.args[1].head == :curly
        args = ex.args[1].args
        for i = length(args):-1:1
            if isa(args[i], Int)
                deleteat!(args, i)
            end
        end
    end
    ex
end

# Strip any number of surrounding :escape wrappers.
function popescape(ex::Expr)
    while ex.head == :escape
        ex = ex.args[1]
    end
    ex
end

# Extract the "function name" (without type parameters) from a call prototype.
function funcsym(prototype::Expr)
    prototype = popescape(prototype)
    prototype.head == :call || error(string(prototype, " must be a call"))
    tmp = prototype.args[1]
    if isa(tmp, Expr) && tmp.head == :curly
        tmp = tmp.args[1]
    end
    return tmp
end

# Rename the function in a call prototype in place, keeping any type parameters.
function funcrename(prototype::Expr, name::Symbol)
    prototype = popescape(prototype)
    prototype.head == :call || error(string(prototype, " must be a call"))
    tmp = prototype.args[1]
    if isa(tmp, Expr) && tmp.head == :curly
        tmp.args[1] = name
    else
        prototype.args[1] = name
    end
    return prototype
end

# True if `sym` appears among the type parameters f{...} of the call prototype.
function hasparameter(prototype::Expr, sym::Symbol)
    prototype = popescape(prototype)
    prototype.head == :call || error(string(prototype, " must be a call"))
    tmp = prototype.args[1]
    if isa(tmp, Expr) && tmp.head == :curly
        for i = 2:length(tmp.args)
            if tmp.args[i] == sym
                return true
            end
        end
    end
    false
end

# Extract the symbols of the function arguments
# (for `x::T` or `x...`, the first sub-expression).
funcarg(s::Symbol) = s
funcarg(ex::Expr) = ex.args[1]
function funcargs(prototype::Expr)
    prototype = popescape(prototype)
    prototype.head == :call || error(string(prototype, " must be a call"))
    map(funcarg, prototype.args[2:end])
end

### Cartesian-specific macros

# Generate N nested loops:
#    @nloops N i A body                              (ranges 1:size(A,d))
#    @nloops N i d->range [preexpr [postexpr]] body
# Dimension 1 is the innermost loop. preexpr(d) runs before the nested loops inside the
# loop over dimension d, and postexpr(d) runs after them.
macro nloops(N, itersym, rangeexpr, args...)
    _nloops(N, itersym, rangeexpr, args...)
end

_nloops(N::Int, itersym::Symbol, arraysym::Symbol, args::Expr...) = _nloops(N, itersym, :(d->1:size($arraysym,d)), args...)

function _nloops(N::Int, itersym::Symbol, rangeexpr::Expr, args::Expr...)
    if rangeexpr.head != :->
        # The range follows N and itersym, so it is the macro's third argument.
        error("Third argument must be an anonymous function expression to compute the range")
    end
    if isempty(args)
        error("Missing loop body")
    elseif length(args) > 3
        error("Too many arguments")
    end
    body = args[end]
    ex = Expr(:escape, body)
    for dim = 1:N
        itervar = inlineanonymous(itersym, dim)
        rng = inlineanonymous(rangeexpr, dim)
        preexpr = length(args) > 1 ? inlineanonymous(args[1], dim) : (:(nothing))
        postexpr = length(args) > 2 ? inlineanonymous(args[2], dim) : (:(nothing))
        ex = quote
            for $(esc(itervar)) = $(esc(rng))
                $(esc(preexpr))
                $ex
                $(esc(postexpr))
            end
        end
    end
    ex
end

# Generate expression A[i_1, i_2, ...] (sym may be a symbol `i` or an anonymous function d->...)
macro nref(N, A, sym)
    _nref(N, A, sym)
end

function _nref(N::Int, A::Symbol, ex)
    vars = [ inlineanonymous(ex,i) for i = 1:N ]
    Expr(:escape, Expr(:ref, A, vars...))
end

# Generate f(preargs..., arg_1, arg_2, ...) where the last macro argument yields arg_i
macro ncall(N, f, sym...)
    _ncall(N, f, sym...)
end

function _ncall(N::Int, f, args...)
    pre = args[1:end-1]
    ex = args[end]
    vars = [ inlineanonymous(ex,i) for i = 1:N ]
    Expr(:escape, Expr(:call, f, pre..., vars...))
end

# Generate N expressions ex(1), ..., ex(N) in a block
macro nexprs(N, ex)
    _nexprs(N, ex)
end

function _nexprs(N::Int, ex::Expr)
    exs = [ inlineanonymous(ex,i) for i = 1:N ]
    Expr(:escape, Expr(:block, exs...))
end

# Make variables esym_1, esym_2, ... = isym[1], isym[2], ...
# (or, when isym is an anonymous function d->..., = isym(1), isym(2), ...)
macro nextract(N, esym, isym)
    _nextract(N, esym, isym)
end

function _nextract(N::Int, esym::Symbol, isym::Symbol)
    aexprs = [Expr(:escape, Expr(:(=), inlineanonymous(esym, i), :(($isym)[$i]))) for i = 1:N]
    Expr(:block, aexprs...)
end

function _nextract(N::Int, esym::Symbol, ex::Expr)
    aexprs = [Expr(:escape, Expr(:(=), inlineanonymous(esym, i), inlineanonymous(ex,i))) for i = 1:N]
    Expr(:block, aexprs...)
end

# Generate criterion(1) && criterion(2) && ... && criterion(N)
macro nall(N, criterion)
    _nall(N, criterion)
end

function _nall(N::Int, criterion::Expr)
    if criterion.head != :->
        error("Second argument must be an anonymous function expression yielding the criterion")
    end
    conds = [Expr(:escape, inlineanonymous(criterion, i)) for i = 1:N]
    Expr(:&&, conds...)
end

# Generate the tuple (ex(1), ex(2), ..., ex(N))
macro ntuple(N, ex)
    _ntuple(N, ex)
end

function _ntuple(N::Int, ex)
    vars = [ inlineanonymous(ex,i) for i = 1:N ]
    Expr(:escape, Expr(:tuple, vars...))
end

# if condition1; operation1; elseif condition2; operation2; else operation3
# You can pass one or two operations; the second, if present, is used in the final "else"
macro nif(N, condition, operation...)
    # Handle the final "else"
    ex = esc(inlineanonymous(length(operation) > 1 ? operation[2] : operation[1], N))
    # Make the nested if statements
    for i = N-1:-1:1
        ex = Expr(:if, esc(inlineanonymous(condition,i)), esc(inlineanonymous(operation[1],i)), ex)
    end
    ex
end

## Utilities

# Simplify expressions like :(d->3:size(A,d)-3) given an explicit value for d
function inlineanonymous(ex::Expr, val)
    if ex.head != :->
        error("Not an anonymous function")
    end
    if !isa(ex.args[1], Symbol)
        error("Not a single-argument anonymous function")
    end
    sym = ex.args[1]
    ex = ex.args[2]
    exout = lreplace(ex, sym, val)
    exout = poplinenum(exout)
    exprresolve(exout)
end

# Given :i and 3, this generates :i_3
inlineanonymous(base::Symbol, ext) = symbol(string(base)*"_"*string(ext))

# Replace a symbol by a value or a "coded" symbol
# E.g., for d = 3,
#    lreplace(:d, :d, 3) -> 3
#    lreplace(:i_d, :d, 3) -> :i_3
#    lreplace(:i_{d-1}, :d, 3) -> :i_2
# This follows LaTeX notation.
lreplace(ex, sym::Symbol, val) = lreplace!(copy(ex), sym, val, Regex("_"*string(sym)*"(\$|(?=_))"))

lreplace!(arg, sym::Symbol, val, r) = arg
function lreplace!(s::Symbol, sym::Symbol, val, r::Regex)
    if (s == sym)
        return val
    end
    # Rewrite a `_sym` suffix (or `_sym_` segment) to `_val`.
    symbol(replace(string(s), r, "_"*string(val)))
end

function lreplace!(ex::Expr, sym::Symbol, val, r)
    # Curly-brace notation, which acts like parentheses
    if ex.head == :curly && length(ex.args) == 2 && isa(ex.args[1], Symbol) && endswith(string(ex.args[1]), "_")
        excurly = exprresolve(lreplace!(ex.args[2], sym, val, r))
        if isa(excurly, Number)
            return symbol(string(ex.args[1])*string(excurly))
        else
            ex.args[2] = excurly
            return ex
        end
    end
    for i in 1:length(ex.args)
        ex.args[i] = lreplace!(ex.args[i], sym, val, r)
    end
    ex
end

# Unwrap a block holding a single expression, or a line marker followed by one expression.
poplinenum(arg) = arg
function poplinenum(ex::Expr)
    if ex.head == :block
        if length(ex.args) == 1
            return ex.args[1]
        elseif length(ex.args) == 2
            first = ex.args[1]
            # The first statement may be a Symbol or literal, which has no `head` field.
            if (isa(first, Expr) && first.head == :line) || isa(first, LineNumberNode)
                return ex.args[2]
            end
        end
    end
    ex
end

# Best-effort constant folding applied bottom-up: numeric +,-,*,/ calls, x+0 / x-0,
# indexing into a literal array with real indices, and `if` whose condition can be
# evaluated at expansion time.
exprresolve(arg) = arg
function exprresolve(ex::Expr)
    for i = 1:length(ex.args)
        ex.args[i] = exprresolve(ex.args[i])
    end
    # Handle simple arithmetic
    if ex.head == :call && in(ex.args[1], (:+, :-, :*, :/)) && all([isa(ex.args[i], Number) for i = 2:length(ex.args)])
        return eval(ex)
    elseif ex.head == :call && (ex.args[1] == :+ || ex.args[1] == :-) && length(ex.args) == 3 && ex.args[3] == 0
        # simplify x+0 and x-0
        return ex.args[2]
    end
    # Resolve array references
    if ex.head == :ref && isa(ex.args[1], Array)
        for i = 2:length(ex.args)
            if !isa(ex.args[i], Real)
                return ex
            end
        end
        return ex.args[1][ex.args[2:end]...]
    end
    # Resolve conditionals. The condition is eval'ed at expansion time. If that fails
    # (e.g. it refers to run-time variables, is not a Bool, or a false condition has no
    # else branch), the `if` is deliberately left unchanged.
    if ex.head == :if
        try
            tf = eval(ex.args[1])
            ex = tf ? ex.args[2] : ex.args[3]
        catch
        end
    end
    ex
end

end