Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
chriselrod committed Sep 16, 2023
1 parent 383b8e5 commit eb7c322
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 78 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Polyester"
uuid = "f517fe37-dbe3-4b94-8317-1923a5111588"
authors = ["Chris Elrod <[email protected]> and contributors"]
version = "0.7.5"
version = "0.7.6"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down
128 changes: 53 additions & 75 deletions src/closure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -225,15 +225,7 @@ function makestatic!(expr)
end
expr
end
function enclose(
exorig::Expr,
reserve_per,
minbatchsize,
per::Symbol,
threadlocal_tuple,
stride,
mod,
)
function enclose(exorig::Expr, minbatchsize, per::Symbol, threadlocal_tuple, stride, mod)
Meta.isexpr(exorig, :for, 2) ||
throw(ArgumentError("Expression invalid; should be a for loop."))
ex = copy(exorig)
Expand All @@ -246,7 +238,8 @@ function enclose(
innerloop = Symbol("##inner##loop##")
rcombiner = Symbol("##split##recombined##")
threadlocal_var = Symbol("threadlocal")

#FIXME: don't do this?
per = stride ? :thread : per
# arguments = Symbol[]#loop_offs, loop_step]
arguments = Symbol[innerloop, rcombiner]#loop_offs, loop_step]
defined = Dict{Symbol,Symbol}(loop_offs => loop_offs, loop_step => loop_step)
Expand Down Expand Up @@ -275,7 +268,7 @@ function enclose(
quote
# for $(firstloop.args[1]) in
var"##outer##"::Int = Int($loopstart)::Int
while $loopstart <= $loop_stop
while var"##outer##" <= $loop_stop
for var"##inner##" in $innerloop
$fla1 = $combine($rcombiner, var"##inner##", var"##outer##")
$body
Expand All @@ -291,7 +284,7 @@ function enclose(
quote
# for $(firstloop.args[1]) in
var"##outer##"::Int = Int($loopstart)::Int
while $loopstart <= $loop_stop
while var"##outer##" <= $loop_stop
for var"##inner##" in $innerloop
$fla1 = $combine($rcombiner, var"##inner##", var"##outer##")
$body
Expand Down Expand Up @@ -325,11 +318,14 @@ function enclose(
Symbol("##NUM#THREADS##")
end

iter_len_def = quote
$(esc(innerloop)), $loop_sym, $(esc(rcombiner)) = $splitloop($(esc(makestatic!(loop))))
$iter_leng = $static_length($loop_sym)
end

q = quote
var"##NUM#THREADS#TO#USE##" = $num_thread_expr
$(esc(innerloop)), $loop_sym, $(esc(rcombiner)) =
$splitloop($(esc(makestatic!(loop))))
$iter_leng = $static_length($loop_sym)
$(stride ? nothing : iter_len_def)
$loop_step = $static_step($loop_sym)
$loop_offs = $static_first($loop_sym)
end
Expand Down Expand Up @@ -392,6 +388,7 @@ function enclose(
local var"##STEP##" = $(stride ? :($loop_step * Threads.nthreads()) : loop_step)
local $loopstart = $loop_start_expr
local $loop_stop = $loop_stop_expr
# $(stride ? :(@show $loopstart, $loop_stop) : nothing)
$threadlocal_get
@inbounds begin
$excomb
Expand Down Expand Up @@ -430,7 +427,11 @@ function enclose(
push!(q.args, batchcall)
quote
var"##NUM#THREADS##" = $(Threads.nthreads())
if var"##NUM#THREADS##" == 1
$(stride ? iter_len_def : nothing)
if (
stride ? :((var"##NUM#THREADS##" == 1) || (var"##NUM#THREADS##" > $iter_leng)) :
(var"##NUM#THREADS##" == 1)
)
single_thread_result = begin
$(esc(threadlocal_init_single)) # Initialize threadlocal storage
$(esc(q_single))
Expand Down Expand Up @@ -496,11 +497,17 @@ This may be better for load balancing if iterations close to each other take a s
`stride=false` is the default.
"""
macro batch(ex)
enclose(macroexpand(__module__, ex), 0, 1, :core, (Symbol(""), :Any), false, __module__)
enclose(
macroexpand(__module__, ex),
1,
:unspecified,
(Symbol(""), :Any),
false,
__module__,
)
end
function interpret_kwarg(
arg,
reserve_per = 0,
minbatch = 1,
per = :unspecified,
threadlocal = (Symbol(""), :Any),
Expand All @@ -509,6 +516,7 @@ function interpret_kwarg(
a = arg.args[1]
v = arg.args[2]
if a === :reserve
@warn "reserve has been deprecated"
@assert v 0
reserve_per = v
elseif a === :minbatch
Expand All @@ -522,72 +530,42 @@ function interpret_kwarg(
else
threadlocal = (v, :Any)
end
elseif a === :stride
stride = v::Bool
else
throw(ArgumentError("kwarg $(a) not recognized."))
end
reserve_per, minbatch, per, threadlocal, stride
minbatch, per, threadlocal, stride
end
macro batch(arg1, ex)
reserve, minbatch, per, threadlocal, stride = interpret_kwarg(arg1)
per = per === :unspecified ? :core : per
enclose(
macroexpand(__module__, ex),
reserve,
minbatch,
per,
threadlocal,
stride,
__module__,
)
minbatch, per, threadlocal, stride = interpret_kwarg(arg1)
per = per === :unspecified ? (stride ? :thread : :core) : per
enclose(macroexpand(__module__, ex), minbatch, per, threadlocal, stride, __module__)
end
macro batch(arg1, arg2, ex)
reserve, minbatch, per, threadlocal, stride = interpret_kwarg(arg1)
reserve, minbatch, per, threadlocal, stride =
interpret_kwarg(arg2, reserve, minbatch, per, threadlocal, stride)
per = per === :unspecified ? :core : per
enclose(
macroexpand(__module__, ex),
reserve,
minbatch,
per,
threadlocal,
stride,
__module__,
)
minbatch, per, threadlocal, stride = interpret_kwarg(arg1)
minbatch, per, threadlocal, stride =
interpret_kwarg(arg2, minbatch, per, threadlocal, stride)
per = per === :unspecified ? (stride ? :thread : :core) : per
enclose(macroexpand(__module__, ex), minbatch, per, threadlocal, stride, __module__)
end
macro batch(arg1, arg2, arg3, ex)
reserve, minbatch, per, threadlocal, stride = interpret_kwarg(arg1)
reserve, minbatch, per, threadlocal, stride =
interpret_kwarg(arg2, reserve, minbatch, per, threadlocal, stride)
reserve, minbatch, per, threadlocal, stride =
interpret_kwarg(arg3, reserve, minbatch, per, threadlocal, stride)
per = per === :unspecified ? :core : per
enclose(
macroexpand(__module__, ex),
reserve,
minbatch,
per,
threadlocal,
stride,
__module__,
)
minbatch, per, threadlocal, stride = interpret_kwarg(arg1)
minbatch, per, threadlocal, stride =
interpret_kwarg(arg2, minbatch, per, threadlocal, stride)
minbatch, per, threadlocal, stride =
interpret_kwarg(arg3, minbatch, per, threadlocal, stride)
per = per === :unspecified ? (stride ? :thread : :core) : per
enclose(macroexpand(__module__, ex), minbatch, per, threadlocal, stride, __module__)
end
macro batch(arg1, arg2, arg3, arg4, ex)
reserve, minbatch, per, threadlocal, stride = interpret_kwarg(arg1)
reserve, minbatch, per, threadlocal, stride =
interpret_kwarg(arg2, reserve, minbatch, per, threadlocal, stride)
reserve, minbatch, per, threadlocal, stride =
interpret_kwarg(arg3, reserve, minbatch, per, threadlocal, stride)
reserve, minbatch, per, threadlocal, stride =
interpret_kwarg(arg3, reserve, minbatch, per, threadlocal, stride)
per = per === :unspecified ? :core : per
enclose(
macroexpand(__module__, ex),
reserve,
minbatch,
per,
threadlocal,
stride,
__module__,
)
minbatch, per, threadlocal, stride = interpret_kwarg(arg1)
minbatch, per, threadlocal, stride =
interpret_kwarg(arg2, minbatch, per, threadlocal, stride)
minbatch, per, threadlocal, stride =
interpret_kwarg(arg3, minbatch, per, threadlocal, stride)
minbatch, per, threadlocal, stride =
interpret_kwarg(arg3, minbatch, per, threadlocal, stride)
per = per === :unspecified ? (stride ? :thread : :core) : per
enclose(macroexpand(__module__, ex), minbatch, per, threadlocal, stride, __module__)
end
20 changes: 18 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@ function bsin!(y, x, r = eachindex(y, x))
end
return y
end
function bsin_stride!(y, x, r = eachindex(y, x))
@batch stride = true for i r
y[i] = sin(x[i])
end
return y
end
function bcos!(y, x)
@batch per = core for i eachindex(y, x)
local cxᵢ
Expand Down Expand Up @@ -99,6 +105,12 @@ function issue25!(dest, x, y)
end
dest
end
function issue25_but_with_strides!(dest, x, y)
@batch stride = true for (i, j) Iterators.product(eachindex(x), eachindex(y))
dest[i, j] = x[i, begin] * y[j, end]
end
dest
end


@testset "Range Map" begin
Expand Down Expand Up @@ -140,6 +152,8 @@ end
y = similar(x)
z = similar(y)
@test bsin!(y, x) == (z .= sin.(x))
fill!(y, NaN)
@test bsin_stride!(y, x) == z
@test bcos!(y, x) == (z .= cos.(x))
@views z[1:3:length(x)] .= sin.(x[1:3:length(x)])
@test bsin!(y, x, 1:3:length(x)) == z
Expand Down Expand Up @@ -168,6 +182,8 @@ end
dest0 = similar(dest1)
# TODO: don't only thread outer
@test issue25!(dest0, x, y) dest1
fill!(dest0, NaN)
@test issue25_but_with_strides!(dest0, x, y) dest1
end
end

Expand Down Expand Up @@ -522,7 +538,7 @@ end
num_threads = min(Threads.nthreads(), sys_threads)

function issue30_set!(dst)
@batch per=thread for i in eachindex(dst)
@batch per = thread for i in eachindex(dst)
dst[i] = Threads.threadid()
end
return dst
Expand All @@ -536,7 +552,7 @@ end
end

function issue30_throw!(dst)
@batch per=thread for i in eachindex(dst)
@batch per = thread for i in eachindex(dst)
dst[i] = Threads.threadid()
if i > 1
throw(DomainError("expected error"))
Expand Down

0 comments on commit eb7c322

Please sign in to comment.