procrustes alignment for pytorch #8577

Open
heth27 opened this issue Aug 9, 2024 · 3 comments

Comments

@heth27

heth27 commented Aug 9, 2024

🚀 The feature

Orthogonal Procrustes alignment

Motivation, pitch

Procrustes alignment is a staple when calculating metrics for 3D human pose estimation (e.g. the Procrustes-aligned MPJPE, PA-MPJPE), but there seems to be no library that offers this function for PyTorch, so I guess everyone just maintains their own version.

SciPy has a variant, `scipy.spatial.procrustes`:
https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.procrustes.html
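For comparison, here is a minimal sketch of what `scipy.spatial.procrustes` computes, ported to PyTorch under the assumption that it follows the behavior described in the SciPy documentation. The name `procrustes_disparity` is hypothetical, not an existing API. Unlike a pose-metric alignment, this variant standardizes *both* inputs to unit Frobenius norm, returns a symmetric disparity score, and does not force the orthogonal matrix to be a proper rotation, so a mirrored point set counts as a perfect match.

```python
import torch


def procrustes_disparity(data1: torch.Tensor, data2: torch.Tensor):
    """Hypothetical torch port of scipy.spatial.procrustes for (N, D) inputs.

    Returns (mtx1, mtx2, disparity): both inputs centered and scaled to unit
    Frobenius norm, mtx2 rotated and scaled onto mtx1, and the sum of squared
    differences between the two.
    """
    assert data1.shape == data2.shape and data1.ndim == 2
    mtx1 = data1.double() - data1.double().mean(dim=0)
    mtx2 = data2.double() - data2.double().mean(dim=0)

    # Standardize each set; the original scale of either input is discarded.
    mtx1 = mtx1 / torch.linalg.norm(mtx1)
    mtx2 = mtx2 / torch.linalg.norm(mtx2)

    # Orthogonal Procrustes via the SVD of the cross-covariance. There is no
    # det(R) = +1 correction here, so reflections are allowed.
    U, S, Vh = torch.linalg.svd(mtx1.T @ mtx2)
    R = U @ Vh
    mtx2 = S.sum() * (mtx2 @ R.T)

    disparity = (mtx1 - mtx2).square().sum()
    return mtx1, mtx2, disparity
```

Because of the standardization, the disparity is the same whichever argument comes first, which makes it a shape-similarity score rather than an alignment of a prediction onto a fixed ground truth.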

Alternatives

No response

Additional context

The implementation I'm using; I don't know whether it is any good.

```python
import torch


def procrustes(pts1: torch.Tensor, pts2: torch.Tensor):
    """Estimate a similarity transform (Sim(3)) that aligns pts2 onto pts1.

    Solves min_{s, R} ||(pts1 - mean_1) - s * (pts2 - mean_2) @ R.T||
    with R a proper rotation (det(R) = +1), i.e. Umeyama's method.

    Returns s, R, mean_1, mean_2 such that
        s * (pts2 - mean_2) @ R.T + mean_1
    is pts2 aligned onto pts1. Equivalently, as a 4x4 matrix [sR, t]
    applied to pts2, with t = mean_1 - s * mean_2 @ R.T.
    """
    assert pts1.shape == pts2.shape, f"{pts1.shape} != {pts2.shape}"
    assert pts1.ndim == 2 and pts1.shape[-1] == 3, f"{pts1.shape}"

    # Remove translation by centering both point clouds (SVD in float64).
    t1 = pts1.mean(dim=0)
    t2 = pts2.mean(dim=0)
    x1 = (pts1 - t1).double()
    x2 = (pts2 - t2).double()

    try:
        # Cross-covariance H = x1^T x2 = U diag(S) Vh.
        U, S, Vh = torch.linalg.svd(x1.T @ x2)
    except torch.linalg.LinAlgError:
        print("Procrustes failed: SVD did not converge!")
        # Fall back to translation-only alignment; still return four values.
        one = torch.ones((), dtype=pts1.dtype, device=pts1.device)
        return one, torch.eye(3, dtype=pts1.dtype, device=pts1.device), t1, t2

    # Optimal rotation is U @ Vh. If that is a reflection (det = -1), flip the
    # direction of the smallest singular value (the last one; S is descending).
    D = torch.ones_like(S)
    D[-1] = torch.sign(torch.linalg.det(U @ Vh))
    R = U @ torch.diag(D) @ Vh

    # Least-squares optimal scale for that rotation (not simply the size ratio).
    s = (S * D).sum() / x2.square().sum()

    return s.to(pts1.dtype), R.to(pts1.dtype), t1, t2
```
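For reference, the closed-form least-squares similarity alignment (Umeyama's method, which extends the Kabsch rotation fit with a scale) can be written as below. The symbols are introduced here for illustration: $X_1, X_2$ are the centered $N \times 3$ point sets with means $\mu_1, \mu_2$, and the goal is to minimize $\lVert X_1 - s X_2 R^\top \rVert_F^2$ over $s > 0$ and $R \in SO(3)$.

$$
H = X_1^\top X_2 = U \Sigma V^\top, \qquad D = \operatorname{diag}\bigl(1,\, 1,\, \det(U V^\top)\bigr)
$$

$$
R = U D V^\top, \qquad s = \frac{\operatorname{tr}(D \Sigma)}{\lVert X_2 \rVert_F^2}, \qquad t = \mu_1 - s R \mu_2
$$

A point $p$ of the second set is then mapped to $s R p + t$. The correction $D$ flips only the singular direction with the smallest singular value, which keeps $R$ a proper rotation while remaining optimal; flipping an arbitrary column of $U V^\top$ instead gives a rotation, but generally not the best one.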

Example usage:

```python
s, R, mean_1, mean_2 = procrustes(coords_3d_true, coords_3d_prediction)
# Row-wise s * R @ p + t, i.e. s * (prediction - mean_2) @ R.T + mean_1
procrustes_aligned = torch.einsum("jd, od -> jo", coords_3d_prediction - mean_2, s * R) + mean_1
```
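Since the motivation is pose-estimation metrics, here is a minimal sketch of a batched variant that returns PA-MPJPE (mean per-joint position error after similarity alignment) for `(B, J, 3)` tensors. The function names `batched_procrustes_align` and `pa_mpjpe` are hypothetical. It assumes the usual convention of aligning the prediction onto the ground truth with a per-sample proper rotation, non-negative scale and translation, and it does not guard against degenerate inputs (e.g. all joints of a prediction at the same point, which would divide by zero).

```python
import torch


def batched_procrustes_align(target: torch.Tensor, pred: torch.Tensor) -> torch.Tensor:
    """Align pred onto target with one similarity transform per sample.

    target, pred: (B, J, 3). Returns the aligned copy of pred, shape (B, J, 3).
    """
    mu_t = target.mean(dim=1, keepdim=True)
    mu_p = pred.mean(dim=1, keepdim=True)
    x1 = (target - mu_t).double()
    x2 = (pred - mu_p).double()

    # Batched cross-covariance H = x1^T x2, shape (B, 3, 3).
    U, S, Vh = torch.linalg.svd(x1.transpose(1, 2) @ x2)

    # Proper rotation: flip the last singular direction where det(U @ Vh) < 0.
    D = torch.ones_like(S)                                   # (B, 3)
    D[:, -1] = torch.sign(torch.linalg.det(U @ Vh))
    R = U @ torch.diag_embed(D) @ Vh                         # (B, 3, 3)

    # Least-squares optimal scale per sample.
    s = (S * D).sum(dim=1) / x2.square().sum(dim=(1, 2))     # (B,)

    aligned = s[:, None, None] * x2 @ R.transpose(1, 2) + mu_t.double()
    return aligned.to(pred.dtype)


def pa_mpjpe(target: torch.Tensor, pred: torch.Tensor) -> torch.Tensor:
    """Per-sample PA-MPJPE, shape (B,), in the same units as the inputs."""
    aligned = batched_procrustes_align(target, pred)
    return (aligned - target).norm(dim=-1).mean(dim=-1)
```

Averaging the per-sample values over a dataset gives the single number usually reported; because the alignment minimizes the summed squared error, it never makes that error worse than the unaligned prediction.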
@heth27
Author

heth27 commented Aug 11, 2024

If this is better suited for, e.g., torchmetrics (https://lightning.ai/docs/torchmetrics/stable/), that would also be good to know.

@NicolasHug
Member

Hi @heth27, and thank you for the feature request. Torchvision doesn't really have holistic support for 3D data in general, so I'm not sure Procrustes alignment would be in scope. We typically add such metrics when they directly relate to one of the CV tasks that torchvision supports (classification, detection, etc.), but 3D human pose is not yet in scope.
Thank you for providing a snippet, I hope it can be useful to users looking for this exact feature.

@NicolasHug
Member

> If this is better suited for e.g. torchmetrics (lightning.ai/docs/torchmetrics/stable) this would also be good to know

It might be in scope for torchmetrics, although note that torchmetrics isn't owned by the PyTorch org, so we don't have any weight in the decision process over there.


2 participants