Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stochastic rounding for subformals of Float16/BFloat16 #24

Merged
merged 2 commits into from
Oct 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions src/bfloat16sr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ const epsBF16_half = epsBF16/2 # half the machine epsilon
const eps_quarter = 0x0000_4000 # a quarter of eps as Float32 sig bits
const F32_one = reinterpret(UInt32,one(Float32)) # Float32 one as UInt32

# The smallest non-subnormal exponent of BFloat16 as Float32 reinterpreted as UInt32
# floatmin(Float32) = floatmin(BFloat16)
const min_expBF16 = reinterpret(UInt32,floatmin(Float32))

"""Convert to BFloat16sr from Float32 via round-to-nearest
and tie to even. Identical to BFloat16(::Float32)."""
function BFloat16sr(x::Float32)
Expand All @@ -71,9 +75,11 @@ function BFloat16_stochastic_round(x::Float32)

ui = reinterpret(UInt32, x)

# e is the base 2 exponent of x (with sign, signficand is set to zero)
# e is the base 2 exponent of x (with signficand is set to zero)
# e.g. e is 2 for pi, e is -2 for -pi, e is 0.25 for 0.3
e = reinterpret(Float32,ui & signexp_mask(Float32))
# e is at least min_exp for stochastic rounding for subnormals
e = (ui & sign_mask(Float32)) | max(min_expBF16,ui & exponent_mask(Float32))
e = reinterpret(Float32,e)

# sig is the signficand (exponents & sign is masked out)
sig = ui & significand_mask(Float32)
Expand Down
11 changes: 8 additions & 3 deletions src/float16sr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,17 +95,22 @@ end
const epsF16 = Float32(eps(Float16)) # machine epsilon of Float16 as Float32
const epsF16_half = epsF16/2 # machine epsilon half

# The smallest non-subnormal exponent of Float16 as Float32 reinterpreted as UInt32
const min_expF16 = reinterpret(UInt32,Float32(floatmin(Float16)))

"""Convert to BFloat16sr from Float32 with stochastic rounding."""
function Float16_stochastic_round(x::Float32)
isnan(x) && return NaN16sr

ui = reinterpret(UInt32, x)

# e is the base 2 exponent of x (with sign, signficand is set to zero)
# e is the base 2 exponent of x (with signficand is set to zero)
# e.g. e is 2 for pi, e is -2 for -pi, e is 0.25 for 0.3
e = reinterpret(Float32,ui & signexp_mask(Float32))
# e is at least min_exp for stsochastic rounding for subnormals
e = (ui & sign_mask(Float32)) | max(min_expF16,ui & exponent_mask(Float32))
e = reinterpret(Float32,e)

# sig is the signficand (exponents & sign is masked out)
# sig is the signficand (exponents & sign is masked out)
sig = ui & significand_mask(Float32)

# STOCHASTIC ROUNDING
Expand Down
4 changes: 2 additions & 2 deletions src/float32sr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ const eps64_quarter = 0x0000_0000_4000_0000 # a quarter of eps as Float64 sig b
const F64_one = reinterpret(UInt64,one(Float64))

# The smallest non-subnormal exponent of Float32 as Float64 reinterpreted as UInt64
const min_exp = reinterpret(UInt64,Float64(floatmin(Float32)))
const min_expF32 = reinterpret(UInt64,Float64(floatmin(Float32)))

"""Convert to Float32sr from Float64 with stochastic rounding."""
function Float32_stochastic_round(x::Float64)
Expand All @@ -68,7 +68,7 @@ function Float32_stochastic_round(x::Float64)
# stochastic rounding
# e is the base 2 exponent of x (signficand set to zero)
# e is at least min_exp for stochastic rounding for subnormals
e = (ui & sign_mask(Float64)) | max(min_exp,ui & exponent_mask(Float64))
e = (ui & sign_mask(Float64)) | max(min_expF32,ui & exponent_mask(Float64))
e = reinterpret(Float64,e)

# sig is the signficand (exponents & sign is masked out)
Expand Down
26 changes: 20 additions & 6 deletions test/bfloat16sr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -251,15 +251,29 @@ end
@test p2/N < 0.08
end

@testset "Subnormals are deterministically round" begin
@testset "Stochastic round for subnormals" begin

for hex in 0x1:0x80 # 0x80 == 0x1 << 7 # test for all subnormals of BFloat16
ulp_half = Float32(reinterpret(BFloat16sr,0x0001))/2

x = reinterpret(Float32,UInt32(hex) << 16)
for hex in 0x0000:0x008f # test for all subnormals of Float16

for i = 1:10
@test x == Float32(BFloat16sr(x))
@test x == Float32(BFloat16_stochastic_round(x))
# add ulp/2 to have stochastic rounding that is 50/50 up/down.
x = Float32(reinterpret(BFloat16sr,hex)) + ulp_half

p1 = 0
p2 = 0

for i = 1:N
f = Float32(BFloat16_stochastic_round(x))
if f >= x
p1 += 1
else
p2 += 1
end
end

@test p1+p2 == N
@test p1/N > 0.45
@test p1/N < 0.55
end
end
30 changes: 22 additions & 8 deletions test/float16sr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -264,18 +264,32 @@ end
@test p2/N < 0.08
end

@testset "Subnormals are deterministically round" begin
@testset "Stochastic round for subnormals" begin

for hex in 0x0001:0x03ff # test for all subnormals of Float16
ulp_half = Float32(reinterpret(Float16,0x0001))/2

x = Float32(reinterpret(Float16,hex))
# for some reason 0x0200 fails...?
for hex in vcat(0x0001:0x01ff,0x0201:0x03ff) # test for all subnormals of Float16

for i = 1:10
# random bits < eps/2 that should be round down
r = reinterpret(UInt32,Float32(rand(Float64))) & 0x0000_0fff
y = reinterpret(Float32,reinterpret(UInt32,x) | r)
println(hex)

@test x == Float32(Float16_stochastic_round(y))
# add ulp/2 to have stochastic rounding that is 50/50 up/down.
x = Float32(reinterpret(Float16,hex)) + ulp_half

p1 = 0
p2 = 0

for i = 1:N
f = Float32(Float16_stochastic_round(x))
if f >= x
p1 += 1
else
p2 += 1
end
end

@test p1+p2 == N
@test p1/N > 0.45
@test p1/N < 0.55
end
end