Skip to content

Commit

Permalink
Improving and bugfixes f16
Browse files Browse the repository at this point in the history
  • Loading branch information
awxkee committed Aug 4, 2024
1 parent ef96cec commit 0ad48e7
Show file tree
Hide file tree
Showing 11 changed files with 936 additions and 18 deletions.
36 changes: 29 additions & 7 deletions src/lib/fast_gaussian.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,16 @@ use colorutils_rs::planar_to_linear::plane_to_linear;
use colorutils_rs::{
linear_to_rgb, linear_to_rgba, rgb_to_linear, rgba_to_linear, TransferFunction,
};
use half::f16;
use num_traits::cast::FromPrimitive;
use num_traits::{AsPrimitive, Float};

use crate::channels_configuration::FastBlurChannels;
use crate::edge_mode::reflect_index;
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
use crate::neon::{
fast_gaussian_horizontal_pass_neon_f32, fast_gaussian_horizontal_pass_neon_u8,
fast_gaussian_horizontal_pass_neon_f16, fast_gaussian_horizontal_pass_neon_f32,
fast_gaussian_horizontal_pass_neon_u8, fast_gaussian_vertical_pass_neon_f16,
fast_gaussian_vertical_pass_neon_f32, fast_gaussian_vertical_pass_neon_u8,
};
#[cfg(all(
Expand Down Expand Up @@ -617,6 +619,7 @@ fn fast_gaussian_impl<
if std::any::type_name::<T>() == "f32"
|| std::any::type_name::<T>() == "f16"
|| std::any::type_name::<T>() == "half::f16"
|| std::any::type_name::<T>() == "half::binary16::f16"
{
_dispatcher_vertical = if BASE_RADIUS_I64_CUTOFF > radius {
fast_gaussian_vertical_pass::<T, f32, f32, CHANNEL_CONFIGURATION, EDGE_MODE>
Expand All @@ -628,12 +631,31 @@ fn fast_gaussian_impl<
} else {
fast_gaussian_horizontal_pass::<T, f64, f64, CHANNEL_CONFIGURATION, EDGE_MODE>
};
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
if std::any::type_name::<T>() == "f32" {
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
{
_dispatcher_vertical =
fast_gaussian_vertical_pass_neon_f32::<T, CHANNEL_CONFIGURATION, EDGE_MODE>;
_dispatcher_horizontal = fast_gaussian_horizontal_pass_neon_f32::<
T,
CHANNEL_CONFIGURATION,
EDGE_MODE,
>;
}
} else if std::any::type_name::<T>() == "f16"
|| std::any::type_name::<T>() == "half::f16"
|| std::any::type_name::<T>() == "half::binary16::f16"
{
_dispatcher_vertical =
fast_gaussian_vertical_pass_neon_f32::<T, CHANNEL_CONFIGURATION, EDGE_MODE>;
_dispatcher_horizontal =
fast_gaussian_horizontal_pass_neon_f32::<T, CHANNEL_CONFIGURATION, EDGE_MODE>;
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
{
_dispatcher_vertical =
fast_gaussian_vertical_pass_neon_f16::<T, CHANNEL_CONFIGURATION, EDGE_MODE>;
_dispatcher_horizontal = fast_gaussian_horizontal_pass_neon_f16::<
T,
CHANNEL_CONFIGURATION,
EDGE_MODE,
>;
}
}
}
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
Expand Down Expand Up @@ -935,7 +957,7 @@ pub fn fast_gaussian_in_linear(
/// # Panics
/// Panic is stride/width/height/channel configuration do not match provided
pub fn fast_gaussian_f16(
bytes: &mut [u16],
bytes: &mut [f16],
width: u32,
height: u32,
radius: u32,
Expand Down
1 change: 1 addition & 0 deletions src/lib/fast_gaussian_next.rs
Original file line number Diff line number Diff line change
Expand Up @@ -545,6 +545,7 @@ fn fast_gaussian_next_impl<
if std::any::type_name::<T>() == "f32"
|| std::any::type_name::<T>() == "f16"
|| std::any::type_name::<T>() == "half::f16"
|| std::any::type_name::<T>() == "half::binary16::f16"
{
_dispatcher_vertical = if BASE_RADIUS_I64_CUTOFF > radius {
fast_gaussian_next_vertical_pass::<T, f32, f32, CHANNEL_CONFIGURATION, EDGE_MODE>
Expand Down
2 changes: 2 additions & 0 deletions src/lib/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ mod neon;
))]
mod sse;
mod stack_blur;
mod stack_blur_f16;
mod stack_blur_f32;
mod stack_blur_linear;
mod threading_policy;
Expand Down Expand Up @@ -83,6 +84,7 @@ pub use r#box::tent_blur_f32;
pub use r#box::tent_blur_in_linear;
pub use r#box::tent_blur_u16;
pub use stack_blur::stack_blur;
pub use stack_blur_f16::stack_blur_f16;
pub use stack_blur_f32::stack_blur_f32;
pub use stack_blur_linear::stack_blur_in_linear;
pub use threading_policy::ThreadingPolicy;
91 changes: 91 additions & 0 deletions src/lib/neon/f16_utils.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
/*
* Copyright (c) Radzivon Bartoshyk. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification,
* are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
use std::arch::aarch64::*;
use std::arch::asm;

/// Provides basic support for f16

#[derive(Debug, Clone, Copy)]
#[allow(non_camel_case_types)]
#[allow(dead_code)]
pub struct x_float16x4_t(pub(crate) uint16x4_t);

#[derive(Debug, Clone, Copy)]
#[allow(non_camel_case_types)]
#[allow(dead_code)]
pub struct x_float16x8_t(pub(crate) uint16x8_t);

#[inline]
pub unsafe fn xvld_f16(ptr: *const half::f16) -> x_float16x4_t {
let store: uint16x4_t = vld1_u16(std::mem::transmute(ptr));
std::mem::transmute(store)
}

#[inline]
pub unsafe fn xreinterpret_u16_f16(x: x_float16x4_t) -> uint16x4_t {
std::mem::transmute(x)
}

#[inline]
pub unsafe fn xreinterpret_f16_u16(x: uint16x4_t) -> x_float16x4_t {
std::mem::transmute(x)
}

// #[inline]
// pub unsafe fn xreinterpretq_f16_u16(x: uint16x8_t) -> x_float16x8_t {
// std::mem::transmute(x)
// }

#[inline]
pub unsafe fn xvcvt_f32_f16(x: x_float16x4_t) -> float32x4_t {
let src: uint16x4_t = xreinterpret_u16_f16(x);
let dst: float32x4_t;
asm!(
"fcvtl {0:v}.4s, {1:v}.4h",
out(vreg) dst,
in(vreg) src,
options(pure, nomem, nostack));
dst
}

#[inline]
pub(super) unsafe fn xvcvt_f16_f32(v: float32x4_t) -> x_float16x4_t {
let result: uint16x4_t;
asm!(
"fcvtn {0:v}.4h, {1:v}.4s",
out(vreg) result,
in(vreg) v,
options(pure, nomem, nostack));
xreinterpret_f16_u16(result)
}

#[inline]
pub unsafe fn xvst_f16(ptr: *const half::f16, x: x_float16x4_t) {
vst1_u16(std::mem::transmute(ptr), xreinterpret_u16_f16(x))
}
188 changes: 188 additions & 0 deletions src/lib/neon/fast_gaussian_f16.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
// Copyright (c) Radzivon Bartoshyk. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

use std::arch::aarch64::*;

use half::f16;

use crate::neon::{load_f32_f16, store_f32_f16};
use crate::unsafe_slice::UnsafeSlice;
use crate::{clamp_edge, reflect_101, reflect_index, EdgeMode};

pub fn fast_gaussian_vertical_pass_neon_f16<
T,
const CHANNELS_COUNT: usize,
const EDGE_MODE: usize,
>(
undef_bytes: &UnsafeSlice<T>,
stride: u32,
width: u32,
height: u32,
radius: u32,
start: u32,
end: u32,
) {
let edge_mode: EdgeMode = EDGE_MODE.into();
let bytes: &UnsafeSlice<'_, f16> = unsafe { std::mem::transmute(undef_bytes) };
let mut buffer: [[f32; 4]; 1024] = [[0f32; 4]; 1024];

let height_wide = height as i64;

let radius_64 = radius as i64;
let weight = 1.0f32 / ((radius as f32) * (radius as f32));
let f_weight = unsafe { vdupq_n_f32(weight) };
for x in start..std::cmp::min(width, end) {
let mut diffs: float32x4_t = unsafe { vdupq_n_f32(0f32) };
let mut summs: float32x4_t = unsafe { vdupq_n_f32(0f32) };

let start_y = 0 - 2 * radius as i64;
for y in start_y..height_wide {
let current_y = (y * (stride as i64)) as usize;

if y >= 0 {
let current_px = (std::cmp::max(x, 0)) as usize * CHANNELS_COUNT;

let prepared_px = unsafe { vmulq_f32(summs, f_weight) };

unsafe {
let dst_ptr = bytes.slice.as_ptr().add(current_y + current_px) as *mut f16;
store_f32_f16::<CHANNELS_COUNT>(dst_ptr, prepared_px);
}

let arr_index = ((y - radius_64) & 1023) as usize;
let d_arr_index = (y & 1023) as usize;

let d_buf_ptr = unsafe { buffer.as_mut_ptr().add(d_arr_index) as *mut f32 };
let mut d_stored = unsafe { vld1q_f32(d_buf_ptr) };
d_stored = unsafe { vmulq_n_f32(d_stored, 2f32) };

let buf_ptr = unsafe { buffer.as_mut_ptr().add(arr_index) as *mut f32 };
let a_stored = unsafe { vld1q_f32(buf_ptr) };

diffs = unsafe { vaddq_f32(diffs, vsubq_f32(a_stored, d_stored)) };
} else if y + radius_64 >= 0 {
let arr_index = (y & 1023) as usize;
let buf_ptr = unsafe { buffer.as_mut_ptr().add(arr_index) as *mut f32 };
let mut stored = unsafe { vld1q_f32(buf_ptr) };
stored = unsafe { vmulq_n_f32(stored, 2f32) };
diffs = unsafe { vsubq_f32(diffs, stored) };
}

let next_row_y =
clamp_edge!(edge_mode, y + radius_64, 0, height_wide - 1) * (stride as usize);
let next_row_x = x as usize * CHANNELS_COUNT;

let s_ptr = unsafe { bytes.slice.as_ptr().add(next_row_y + next_row_x) as *mut f16 };
let pixel_color = unsafe { load_f32_f16::<CHANNELS_COUNT>(s_ptr) };

let arr_index = ((y + radius_64) & 1023) as usize;
let buf_ptr = unsafe { buffer.as_mut_ptr().add(arr_index) as *mut f32 };

diffs = unsafe { vaddq_f32(diffs, pixel_color) };
summs = unsafe { vaddq_f32(summs, diffs) };
unsafe {
vst1q_f32(buf_ptr, pixel_color);
}
}
}
}

pub fn fast_gaussian_horizontal_pass_neon_f16<
T,
const CHANNELS_COUNT: usize,
const EDGE_MODE: usize,
>(
undef_bytes: &UnsafeSlice<T>,
stride: u32,
width: u32,
height: u32,
radius: u32,
start: u32,
end: u32,
) {
let edge_mode: EdgeMode = EDGE_MODE.into();
let bytes: &UnsafeSlice<'_, f16> = unsafe { std::mem::transmute(undef_bytes) };
let mut buffer: [[f32; 4]; 1024] = [[0f32; 4]; 1024];
let radius_64 = radius as i64;
let width_wide = width as i64;
let weight = 1.0f32 / ((radius as f32) * (radius as f32));
let f_weight = unsafe { vdupq_n_f32(weight) };
for y in start..std::cmp::min(height, end) {
let mut diffs: float32x4_t = unsafe { vdupq_n_f32(0f32) };
let mut summs: float32x4_t = unsafe { vdupq_n_f32(0f32) };

let current_y = ((y as i64) * (stride as i64)) as usize;

let start_x = 0 - 2 * radius_64;
for x in start_x..(width as i64) {
if x >= 0 {
let current_px = (std::cmp::max(x, 0) as u32) as usize * CHANNELS_COUNT;

let prepared_px = unsafe { vmulq_f32(summs, f_weight) };

unsafe {
let dst_ptr = bytes.slice.as_ptr().add(current_y + current_px) as *mut f16;
store_f32_f16::<CHANNELS_COUNT>(dst_ptr, prepared_px);
}

let arr_index = ((x - radius_64) & 1023) as usize;
let d_arr_index = (x & 1023) as usize;

let d_buf_ptr = unsafe { buffer.as_mut_ptr().add(d_arr_index) as *mut f32 };
let mut d_stored = unsafe { vld1q_f32(d_buf_ptr) };
d_stored = unsafe { vmulq_n_f32(d_stored, 2f32) };

let buf_ptr = unsafe { buffer.as_mut_ptr().add(arr_index) as *mut f32 };
let a_stored = unsafe { vld1q_f32(buf_ptr) };

diffs = unsafe { vaddq_f32(diffs, vsubq_f32(a_stored, d_stored)) };
} else if x + radius_64 >= 0 {
let arr_index = (x & 1023) as usize;
let buf_ptr = unsafe { buffer.as_mut_ptr().add(arr_index) as *mut f32 };
let mut stored = unsafe { vld1q_f32(buf_ptr) };
stored = unsafe { vmulq_n_f32(stored, 2f32) };
diffs = unsafe { vsubq_f32(diffs, stored) };
}

let next_row_y = (y as usize) * (stride as usize);
let next_row_x = clamp_edge!(edge_mode, x + radius_64, 0, width_wide - 1);
let next_row_px = next_row_x * CHANNELS_COUNT;

let s_ptr = unsafe { bytes.slice.as_ptr().add(next_row_y + next_row_px) as *mut f16 };
let pixel_color = unsafe { load_f32_f16::<CHANNELS_COUNT>(s_ptr) };

let arr_index = ((x + radius_64) & 1023) as usize;
let buf_ptr = unsafe { buffer.as_mut_ptr().add(arr_index) as *mut f32 };

diffs = unsafe { vaddq_f32(diffs, pixel_color) };
summs = unsafe { vaddq_f32(summs, diffs) };
unsafe {
vst1q_f32(buf_ptr, pixel_color);
}
}
}
}
5 changes: 5 additions & 0 deletions src/lib/neon/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ mod stack_blur_f32;
mod stack_blur_i32;
mod stack_blur_i64;
mod utils;
mod f16_utils;
mod stack_blur_f16;
mod fast_gaussian_f16;

pub use fast_gaussian::*;
pub use fast_gaussian_f32::fast_gaussian_horizontal_pass_neon_f32;
Expand All @@ -44,3 +47,5 @@ pub use stack_blur_f32::stack_blur_pass_neon_f32;
pub use stack_blur_i32::*;
pub use stack_blur_i64::stack_blur_pass_neon_i64;
pub use utils::*;
pub use stack_blur_f16::stack_blur_pass_neon_f16;
pub use fast_gaussian_f16::{fast_gaussian_vertical_pass_neon_f16, fast_gaussian_horizontal_pass_neon_f16};
Loading

0 comments on commit 0ad48e7

Please sign in to comment.