Skip to content

Commit

Permalink
Added SSE fast Gaussian passes for f16, plus some fixes for the u8 SSE passes
Browse files (browse the repository at this point in the history)
  • Loading branch information
awxkee committed Aug 4, 2024
1 parent a526b08 commit 4a96535
Show file tree
Hide file tree
Showing 11 changed files with 315 additions and 83 deletions.
37 changes: 30 additions & 7 deletions src/lib/fast_gaussian.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,16 @@ use crate::neon::{
fast_gaussian_horizontal_pass_neon_u8, fast_gaussian_vertical_pass_neon_f16,
fast_gaussian_vertical_pass_neon_f32, fast_gaussian_vertical_pass_neon_u8,
};
#[cfg(all(
any(target_arch = "x86_64", target_arch = "x86"),
all(target_feature = "sse4.1", target_feature = "f16c")
))]
use crate::sse::{fast_gaussian_horizontal_pass_sse_f16, fast_gaussian_vertical_pass_sse_f16};
#[cfg(all(
any(target_arch = "x86_64", target_arch = "x86"),
target_feature = "sse4.1"
))]
use crate::sse::{fast_gaussian_horizontal_pass_sse_f32, fast_gaussian_vertical_pass_sse_f32};
#[cfg(all(
any(target_arch = "x86_64", target_arch = "x86"),
target_feature = "sse4.1"
Expand All @@ -51,11 +61,6 @@ use crate::threading_policy::ThreadingPolicy;
use crate::to_storage::ToStorage;
use crate::unsafe_slice::UnsafeSlice;
use crate::{clamp_edge, reflect_101, EdgeMode};
#[cfg(all(
any(target_arch = "x86_64", target_arch = "x86"),
target_feature = "sse4.1"
))]
use crate::sse::{fast_gaussian_vertical_pass_sse_f32, fast_gaussian_horizontal_pass_sse_f32};

const BASE_RADIUS_I64_CUTOFF: u32 = 180;

Expand Down Expand Up @@ -652,8 +657,13 @@ fn fast_gaussian_impl<
target_feature = "sse4.1"
))]
{
_dispatcher_vertical = fast_gaussian_vertical_pass_sse_f32::<T, CHANNEL_CONFIGURATION, EDGE_MODE>;
_dispatcher_horizontal = fast_gaussian_horizontal_pass_sse_f32::<T, CHANNEL_CONFIGURATION, EDGE_MODE>;
_dispatcher_vertical =
fast_gaussian_vertical_pass_sse_f32::<T, CHANNEL_CONFIGURATION, EDGE_MODE>;
_dispatcher_horizontal = fast_gaussian_horizontal_pass_sse_f32::<
T,
CHANNEL_CONFIGURATION,
EDGE_MODE,
>;
}
} else if std::any::type_name::<T>() == "f16"
|| std::any::type_name::<T>() == "half::f16"
Expand All @@ -669,6 +679,19 @@ fn fast_gaussian_impl<
EDGE_MODE,
>;
}
#[cfg(all(
any(target_arch = "x86_64", target_arch = "x86"),
all(target_feature = "sse4.1", target_feature = "f16c")
))]
{
_dispatcher_vertical =
fast_gaussian_vertical_pass_sse_f16::<T, CHANNEL_CONFIGURATION, EDGE_MODE>;
_dispatcher_horizontal = fast_gaussian_horizontal_pass_sse_f16::<
T,
CHANNEL_CONFIGURATION,
EDGE_MODE,
>;
}
}
}
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
Expand Down
4 changes: 2 additions & 2 deletions src/lib/gaussian/avx/mod.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
mod filter_vertical_f32;
mod utils;
mod vertical;
mod vertical_f32;
mod filter_vertical_f32;

pub use filter_vertical_f32::gaussian_blur_vertical_pass_filter_f32_avx;
pub use vertical::gaussian_blur_vertical_pass_impl_avx;
pub use vertical_f32::gaussian_blur_vertical_pass_impl_f32_avx;
pub use filter_vertical_f32::gaussian_blur_vertical_pass_filter_f32_avx;
10 changes: 5 additions & 5 deletions src/lib/gaussian/gaussian_kernel_filter_dispatch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#[cfg(all(
any(target_arch = "x86_64", target_arch = "x86"),
target_feature = "avx2"
))]
use crate::gaussian::avx::gaussian_blur_vertical_pass_filter_f32_avx;
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
use crate::gaussian::gauss_neon::*;
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
Expand All @@ -51,11 +56,6 @@ use crate::to_storage::ToStorage;
use crate::unsafe_slice::UnsafeSlice;
use num_traits::{AsPrimitive, FromPrimitive};
use rayon::ThreadPool;
#[cfg(all(
any(target_arch = "x86_64", target_arch = "x86"),
target_feature = "avx2"
))]
use crate::gaussian::avx::gaussian_blur_vertical_pass_filter_f32_avx;

pub(crate) fn gaussian_blur_vertical_pass_edge_clip_dispatch<
T: FromPrimitive + Default + Send + Sync,
Expand Down
12 changes: 7 additions & 5 deletions src/lib/neon/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,27 +25,29 @@
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

mod f16_utils;
mod fast_gaussian;
mod fast_gaussian_f16;
mod fast_gaussian_f32;
mod fast_gaussian_next;
mod fast_gaussian_next_f32;
mod stack_blur_f16;
mod stack_blur_f32;
mod stack_blur_i32;
mod stack_blur_i64;
mod utils;
mod f16_utils;
mod stack_blur_f16;
mod fast_gaussian_f16;

pub use fast_gaussian::*;
pub use fast_gaussian_f16::{
fast_gaussian_horizontal_pass_neon_f16, fast_gaussian_vertical_pass_neon_f16,
};
pub use fast_gaussian_f32::fast_gaussian_horizontal_pass_neon_f32;
pub use fast_gaussian_f32::fast_gaussian_vertical_pass_neon_f32;
pub use fast_gaussian_next::*;
pub use fast_gaussian_next_f32::fast_gaussian_next_horizontal_pass_neon_f32;
pub use fast_gaussian_next_f32::fast_gaussian_next_vertical_pass_neon_f32;
pub use stack_blur_f16::stack_blur_pass_neon_f16;
pub use stack_blur_f32::stack_blur_pass_neon_f32;
pub use stack_blur_i32::*;
pub use stack_blur_i64::stack_blur_pass_neon_i64;
pub use utils::*;
pub use stack_blur_f16::stack_blur_pass_neon_f16;
pub use fast_gaussian_f16::{fast_gaussian_vertical_pass_neon_f16, fast_gaussian_horizontal_pass_neon_f16};
4 changes: 2 additions & 2 deletions src/lib/neon/stack_blur_f16.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

use std::arch::aarch64::*;
use half::f16;
use crate::neon::{load_f32_f16, store_f32_f16};
use crate::stack_blur::StackBlurPass;
use crate::unsafe_slice::UnsafeSlice;
use half::f16;
use std::arch::aarch64::*;

pub fn stack_blur_pass_neon_f16<const COMPONENTS: usize>(
pixels: &UnsafeSlice<f16>,
Expand Down
64 changes: 28 additions & 36 deletions src/lib/sse/fast_gaussian.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,16 @@
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

use crate::mul_table::{MUL_TABLE_DOUBLE, SHR_TABLE_DOUBLE};
#[cfg(target_arch = "x86")]
use std::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

use crate::reflect_101;
use crate::reflect_index;
use crate::sse::_mm_packus_epi64;
use crate::sse::utils::load_u8_s32_fast;
use crate::unsafe_slice::UnsafeSlice;
use crate::{clamp_edge, EdgeMode};
use erydanos::_mm_mul_epi64;
#[cfg(target_arch = "x86")]
use std::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

pub fn fast_gaussian_horizontal_pass_sse_u8<
T,
Expand All @@ -58,10 +56,9 @@ pub fn fast_gaussian_horizontal_pass_sse_u8<

let radius_64 = radius as i64;
let width_wide = width as i64;
let mul_value = MUL_TABLE_DOUBLE[radius as usize];
let shr_value = SHR_TABLE_DOUBLE[radius as usize];
let v_mul_value = unsafe { _mm_set1_epi64x(mul_value as i64) };
let v_shr_value = unsafe { _mm_setr_epi32(shr_value, 0, 0, 0) };
const ROUNDING_FLAGS: i32 = _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC;

let v_weight = unsafe { _mm_set1_ps(1f32 / (radius as f32 * radius as f32)) };
for y in start..std::cmp::min(height, end) {
let mut diffs = unsafe { _mm_setzero_si128() };
let mut summs = unsafe { _mm_set1_epi32(initial_sum) };
Expand All @@ -73,16 +70,13 @@ pub fn fast_gaussian_horizontal_pass_sse_u8<
if x >= 0 {
let current_px = ((std::cmp::max(x, 0) as u32) * CHANNELS_COUNT as u32) as usize;

let hi_a = unsafe { _mm_unpackhi_epi32(summs, _mm_setzero_si128()) };
let lo_b = unsafe { _mm_unpacklo_epi32(summs, _mm_setzero_si128()) };
let blurred_hi =
unsafe { _mm_srl_epi64(_mm_mul_epi64(hi_a, v_mul_value), v_shr_value) };
let blurred_lo =
unsafe { _mm_srl_epi64(_mm_mul_epi64(lo_b, v_mul_value), v_shr_value) };
let prepared_px_s32 = unsafe { _mm_packus_epi64(blurred_lo, blurred_hi) };
let prepared_u16 = unsafe { _mm_packus_epi32(prepared_px_s32, prepared_px_s32) };
let prepared_u8 = unsafe { _mm_packus_epi16(prepared_u16, prepared_u16) };
let pixel = unsafe { _mm_extract_epi32::<0>(prepared_u8) };
let pixel_f32 = unsafe {
_mm_round_ps::<ROUNDING_FLAGS>(_mm_mul_ps(_mm_cvtepi32_ps(summs), v_weight))
};
let pixel_u32 = unsafe { _mm_cvtps_epi32(pixel_f32) };
let pixel_u16 = unsafe { _mm_packus_epi32(pixel_u32, pixel_u32) };
let pixel_u8 = unsafe { _mm_packus_epi16(pixel_u16, pixel_u16) };
let pixel = unsafe { _mm_extract_epi32::<0>(pixel_u8) };

let bytes_offset = current_y + current_px;

Expand Down Expand Up @@ -161,10 +155,11 @@ pub(crate) fn fast_gaussian_vertical_pass_sse_u8<
let height_wide = height as i64;

let radius_64 = radius as i64;
let mul_value = MUL_TABLE_DOUBLE[radius as usize];
let shr_value = SHR_TABLE_DOUBLE[radius as usize];
let v_mul_value = unsafe { _mm_set1_epi64x(mul_value as i64) };
let v_shr_value = unsafe { _mm_setr_epi32(shr_value, 0, 0, 0) };

const ROUNDING_FLAGS: i32 = _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC;

let v_weight = unsafe { _mm_set1_ps(1f32 / (radius as f32 * radius as f32)) };

for x in start..std::cmp::min(width, end) {
let mut diffs = unsafe { _mm_setzero_si128() };
let mut summs = unsafe { _mm_set1_epi32(initial_sum) };
Expand All @@ -176,16 +171,13 @@ pub(crate) fn fast_gaussian_vertical_pass_sse_u8<
if y >= 0 {
let current_px = ((std::cmp::max(x, 0)) * CHANNELS_COUNT as u32) as usize;

let hi_a = unsafe { _mm_unpackhi_epi32(summs, _mm_setzero_si128()) };
let lo_b = unsafe { _mm_unpacklo_epi32(summs, _mm_setzero_si128()) };
let blurred_hi =
unsafe { _mm_srl_epi64(_mm_mul_epi64(hi_a, v_mul_value), v_shr_value) };
let blurred_lo =
unsafe { _mm_srl_epi64(_mm_mul_epi64(lo_b, v_mul_value), v_shr_value) };
let prepared_px_s32 = unsafe { _mm_packus_epi64(blurred_lo, blurred_hi) };
let prepared_u16 = unsafe { _mm_packus_epi32(prepared_px_s32, prepared_px_s32) };
let prepared_u8 = unsafe { _mm_packus_epi16(prepared_u16, prepared_u16) };
let pixel = unsafe { _mm_extract_epi32::<0>(prepared_u8) };
let pixel_f32 = unsafe {
_mm_round_ps::<ROUNDING_FLAGS>(_mm_mul_ps(_mm_cvtepi32_ps(summs), v_weight))
};
let pixel_u32 = unsafe { _mm_cvtps_epi32(pixel_f32) };
let pixel_u16 = unsafe { _mm_packus_epi32(pixel_u32, pixel_u32) };
let pixel_u8 = unsafe { _mm_packus_epi16(pixel_u16, pixel_u16) };
let pixel = unsafe { _mm_extract_epi32::<0>(pixel_u8) };

let bytes_offset = current_y + current_px;

Expand All @@ -212,7 +204,7 @@ pub(crate) fn fast_gaussian_vertical_pass_sse_u8<
let mut d_stored = unsafe { _mm_loadu_si128(d_buf_ptr as *const __m128i) };
d_stored = unsafe { _mm_slli_epi32::<1>(d_stored) };

let buf_ptr = unsafe { buffer.as_mut_ptr().add(arr_index) as * mut i32 };
let buf_ptr = unsafe { buffer.as_mut_ptr().add(arr_index) as *mut i32 };
let a_stored = unsafe { _mm_loadu_si128(buf_ptr as *const __m128i) };

diffs = unsafe { _mm_add_epi32(diffs, _mm_sub_epi32(a_stored, d_stored)) };
Expand Down
Loading

0 comments on commit 4a96535

Please sign in to comment.