diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 4aba752..fef3906 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -14,3 +14,4 @@ jobs: uses: janosh/workflows/.github/workflows/npm-test-release.yml@main with: install-cmd: npm install --force + test-cmd: npm run test:unit diff --git a/src/lib/structure/StructureScene.svelte b/src/lib/structure/StructureScene.svelte index 253ea1b..11fd31a 100644 --- a/src/lib/structure/StructureScene.svelte +++ b/src/lib/structure/StructureScene.svelte @@ -51,7 +51,7 @@ export let active_site: Site | null = null export let precision: string = `.3~f` export let auto_rotate: number | boolean = 0 // auto rotate speed. set to 0 to disable auto rotation. - export let bond_radius: number | undefined = undefined + export let bond_radius: number | undefined = 0.05 export let bond_opacity: number = 0.5 export let bond_color: string = `#ffffff` // must be hex code for export let bonding_strategy: keyof typeof bonding_strategies = `nearest_neighbor` @@ -84,7 +84,7 @@ } // make bond thickness reactive to atom_radius unless bond_radius is set - $: bond_thickness = bond_radius ?? 0.1 * atom_radius + $: bond_thickness = bond_radius ?? 0.05 * atom_radius const gizmo_defaults: Partial> = { horizontalPlacement: `left`, size: 100, diff --git a/src/lib/structure/bonding.ts b/src/lib/structure/bonding.ts index 0c63f93..e022616 100644 --- a/src/lib/structure/bonding.ts +++ b/src/lib/structure/bonding.ts @@ -1,27 +1,30 @@ import type { BondPair, PymatgenStructure } from '$lib' -import { euclidean_dist } from '$lib' -// TODO add unit tests for these functions +export type BondingAlgo = typeof max_dist | typeof nearest_neighbor + export function max_dist( structure: PymatgenStructure, { max_bond_dist = 3, min_bond_dist = 0.4 } = {}, // in Angstroms ): BondPair[] { // finds all pairs of atoms within the max_bond_dist cutoff const bonds: BondPair[] = [] - const bond_set: Set = new Set() + const bond_set = new Set() + const max_bond_dist_sq = max_bond_dist ** 2 + const min_bond_dist_sq = min_bond_dist ** 2 - for (let idx = 0; idx < structure.sites.length; idx++) { - const { xyz } = structure.sites[idx] + for (let i = 0; i < structure.sites.length; i++) { + const { xyz: xyz1 } = structure.sites[i] - for (let idx_2 = idx + 1; idx_2 < structure.sites.length; idx_2++) { - const { xyz: xyz_2 } = structure.sites[idx_2] + for (let j = i + 1; j < structure.sites.length; j++) { + const { xyz: xyz2 } = structure.sites[j] - const dist = euclidean_dist(xyz, xyz_2) - if (dist < max_bond_dist && dist > min_bond_dist) { - const bond_key = [xyz, xyz_2].sort().toString() + const dist_sq = euclidean_dist_sq(xyz1, xyz2) + if (dist_sq <= max_bond_dist_sq && dist_sq >= min_bond_dist_sq) { + const dist = Math.sqrt(dist_sq) + const bond_key = `${i},${j}` if (!bond_set.has(bond_key)) { bond_set.add(bond_key) - bonds.push([xyz, xyz_2, idx, idx_2, dist]) + bonds.push([xyz1, xyz2, i, j, dist]) } } } @@ -34,29 +37,41 @@ export function nearest_neighbor( { scaling_factor = 1.2, min_bond_dist = 0.1 } = {}, // in Angstroms ): BondPair[] { // finds bonds to sites less than scaling_factor farther away than the nearest neighbor + const num_sites = structure.sites.length const bonds: BondPair[] = [] - const bond_set: Set = new Set() + const bond_set = new Set() + const min_bond_dist_sq = min_bond_dist ** 2 + + const nearest_distances = new Array(num_sites).fill(Infinity) + // First pass: find nearest neighbor distances for (let i = 0; i < num_sites; i++) { const { xyz: xyz1 } = structure.sites[i] - let min_dist = Infinity for (let j = i + 1; j < num_sites; j++) { const { xyz: xyz2 } = structure.sites[j] - const dist = euclidean_dist(xyz1, xyz2) + const dist_sq = euclidean_dist_sq(xyz1, xyz2) - if (dist > min_bond_dist && dist < min_dist) { - min_dist = dist + if (dist_sq >= min_bond_dist_sq) { + if (dist_sq < nearest_distances[i]) nearest_distances[i] = dist_sq + if (dist_sq < nearest_distances[j]) nearest_distances[j] = dist_sq } } + } + + // Second pass: add bonds within scaled distance + for (let i = 0; i < num_sites; i++) { + const { xyz: xyz1 } = structure.sites[i] + const max_dist_sq = nearest_distances[i] * scaling_factor ** 2 for (let j = i + 1; j < num_sites; j++) { const { xyz: xyz2 } = structure.sites[j] - const dist = euclidean_dist(xyz1, xyz2) + const dist_sq = euclidean_dist_sq(xyz1, xyz2) - if (dist <= min_dist * scaling_factor) { - const bond_key = [xyz1, xyz2].sort().toString() + if (dist_sq >= min_bond_dist_sq && dist_sq <= max_dist_sq) { + const dist = Math.sqrt(dist_sq) + const bond_key = `${i},${j}` if (!bond_set.has(bond_key)) { bond_set.add(bond_key) bonds.push([xyz1, xyz2, i, j, dist]) @@ -67,3 +82,9 @@ export function nearest_neighbor( return bonds } + +// redundant functionality-wise with euclidean_dist from $lib/math.ts but needed for performance +// makes bonding algos 2-3x faster +function euclidean_dist_sq(vec_a: number[], vec_b: number[]): number { + return vec_a.reduce((sum, _, i) => sum + (vec_a[i] - vec_b[i]) ** 2, 0) +} diff --git a/tests/unit/bonding.test.ts b/tests/unit/bonding.test.ts new file mode 100644 index 0000000..dc02de9 --- /dev/null +++ b/tests/unit/bonding.test.ts @@ -0,0 +1,194 @@ +import type { PymatgenStructure } from '$lib/structure' +import type { BondingAlgo } from '$lib/structure/bonding' +import { max_dist, nearest_neighbor } from '$lib/structure/bonding' +import { performance } from 'perf_hooks' +import { describe, expect, test } from 'vitest' + +const ci_max_time_multiplier = process.env.CI ? 5 : 1 + +// Function to generate a random structure +function make_rand_structure(numAtoms: number) { + return { + sites: Array.from({ length: numAtoms }, () => ({ + xyz: [Math.random() * 10, Math.random() * 10, Math.random() * 10], + })), + } as PymatgenStructure +} + +// Updated performance test function +function perf_test(func: BondingAlgo, atom_count: number, max_time: number) { + const run = () => { + const structure = make_rand_structure(atom_count) + const start = performance.now() + func(structure) + const end = performance.now() + return end - start + } + + const time1 = run() + const time2 = run() + const avg_time = (time1 + time2) / 2 + + expect( + avg_time, + `average run time: ${Math.ceil(avg_time)}, max expected: ${max_time * ci_max_time_multiplier}`, // Apply scaling factor + ).toBeLessThanOrEqual(max_time * ci_max_time_multiplier) +} + +describe(`Bonding Functions Performance Tests`, () => { + const bonding_functions = [ + { + func: max_dist, + max_times: [ + [10, 0.1], + [100, 1], + [1000, 40], + [5000, 1000], + ], + }, + { + func: nearest_neighbor, + max_times: [ + [10, 0.2], + [100, 3], + [1000, 50], + [5000, 1000], + ], + }, + ] + + for (const { func, max_times } of bonding_functions) { + for (const [atom_count, max_time] of max_times) { + test(`${func.name} performance for ${atom_count} atoms`, () => { + perf_test(func, atom_count, max_time) + }) + } + } +}) + +// Helper function to create a simple structure +const make_struct = (sites: number[][]): PymatgenStructure => ({ + sites: sites.map((xyz) => ({ xyz })), +}) + +describe(`max_dist function`, () => { + test(`should return correct bonds for a simple structure`, () => { + const structure = make_struct([ + [0, 0, 0], + [1, 0, 0], + [0, 1, 0], + [0, 0, 1], + ]) + const bonds = max_dist(structure, { + max_bond_dist: 1.5, + min_bond_dist: 0.5, + }) + expect(bonds).toHaveLength(6) + expect(bonds).toContainEqual([[0, 0, 0], [1, 0, 0], 0, 1, 1]) + expect(bonds).toContainEqual([[0, 0, 0], [0, 1, 0], 0, 2, 1]) + expect(bonds).toContainEqual([[0, 0, 0], [0, 0, 1], 0, 3, 1]) + }) + + test(`should not return bonds shorter than min_bond_dist`, () => { + const structure = make_struct([ + [0, 0, 0], + [0.3, 0, 0], + ]) + const bonds = max_dist(structure, { max_bond_dist: 1, min_bond_dist: 0.5 }) + expect(bonds).toHaveLength(0) + }) + + test(`should not return bonds longer than max_bond_dist`, () => { + const structure = make_struct([ + [0, 0, 0], + [2, 0, 0], + ]) + const bonds = max_dist(structure, { + max_bond_dist: 1.5, + min_bond_dist: 0.5, + }) + expect(bonds).toHaveLength(0) + }) + + test(`should handle empty structures`, () => { + const structure = make_struct([]) + const bonds = max_dist(structure) + expect(bonds).toHaveLength(0) + }) +}) + +describe(`nearest_neighbor function`, () => { + test(`should return correct bonds for a simple structure`, () => { + const structure = make_struct([ + [0, 0, 0], + [1, 0, 0], + [0, 1, 0], + [0, 0, 1], + [2, 0, 0], + ]) + const bonds = nearest_neighbor(structure, { + scaling_factor: 1.1, + min_bond_dist: 0.5, + }) + expect(bonds).toHaveLength(4) + expect(bonds).toContainEqual([[0, 0, 0], [1, 0, 0], 0, 1, 1]) + expect(bonds).toContainEqual([[0, 0, 0], [0, 1, 0], 0, 2, 1]) + expect(bonds).toContainEqual([[0, 0, 0], [0, 0, 1], 0, 3, 1]) + }) + + test(`should not return bonds shorter than min_bond_dist`, () => { + const structure = make_struct([ + [0, 0, 0], + [0.05, 0, 0], + [1, 0, 0], + ]) + const bonds = nearest_neighbor(structure, { + scaling_factor: 1.2, + min_bond_dist: 0.1, + }) + expect(bonds).toHaveLength(2) + expect(bonds).toContainEqual([[0, 0, 0], [1, 0, 0], 0, 2, 1]) + }) + + test(`should handle structures with multiple equidistant nearest neighbors`, () => { + const structure = make_struct([ + [0, 0, 0], + [1, 0, 0], + [0, 1, 0], + [0, 0, 1], + ]) + const bonds = nearest_neighbor(structure, { + scaling_factor: 1.1, + min_bond_dist: 0.5, + }) + expect(bonds).toHaveLength(3) + expect(bonds).toContainEqual([[0, 0, 0], [1, 0, 0], 0, 1, 1]) + expect(bonds).toContainEqual([[0, 0, 0], [0, 1, 0], 0, 2, 1]) + expect(bonds).toContainEqual([[0, 0, 0], [0, 0, 1], 0, 3, 1]) + }) + + test(`should handle empty structures`, () => { + const structure = make_struct([]) + const bonds = nearest_neighbor(structure) + expect(bonds).toHaveLength(0) + }) + + test(`should respect the scaling_factor`, () => { + const structure = make_struct([ + [0, 0, 0], + [1, 0, 0], + [0, 1, 0], + [0, 0, 1], + [1.5, 0, 0], + ]) + const bonds = nearest_neighbor(structure, { + scaling_factor: 1.4, + min_bond_dist: 0.5, + }) + expect(bonds).toHaveLength(4) + expect(bonds).toContainEqual([[0, 0, 0], [1, 0, 0], 0, 1, 1]) + expect(bonds).toContainEqual([[0, 0, 0], [0, 1, 0], 0, 2, 1]) + expect(bonds).toContainEqual([[0, 0, 0], [0, 0, 1], 0, 3, 1]) + expect(bonds).toContainEqual([[1, 0, 0], [1.5, 0, 0], 1, 4, 0.5]) + }) +})