Skip to content

Commit

Permalink
received emails from confused researchers re: pytest-examples this mo…
Browse files Browse the repository at this point in the history
…rning. get rid of it
  • Loading branch information
lucidrains committed May 10, 2024
1 parent 85f03c5 commit 34b9e97
Show file tree
Hide file tree
Showing 5 changed files with 334 additions and 89 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@ jobs:
run: |
rye sync
- name: Run pytest
run: rye run pytest --cov=. tests/test_examples_readme.py
run: rye run pytest --cov=. tests/
113 changes: 53 additions & 60 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ vq = VectorQuantize(

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x) # (1, 1024, 256), (1, 1024), (1)
print(quantized.shape, indices.shape, commit_loss.shape)
#> torch.Size([1, 1024, 256]) torch.Size([1, 1024]) torch.Size([1])

```

## Residual VQ
Expand All @@ -49,13 +48,13 @@ x = torch.randn(1, 1024, 256)

quantized, indices, commit_loss = residual_vq(x)
print(quantized.shape, indices.shape, commit_loss.shape)
#> torch.Size([1, 1024, 256]) torch.Size([1, 1024, 8]) torch.Size([1, 8])
# (1, 1024, 256), (1, 1024, 8), (1, 8)

# if you need all the codes across the quantization layers, just pass return_all_codes = True

quantized, indices, commit_loss, all_codes = residual_vq(x, return_all_codes = True)
print(all_codes.shape)
#> torch.Size([8, 1, 1024, 256])

# (8, 1, 1024, 256)
```

Furthermore, <a href="https://arxiv.org/abs/2203.01941">this paper</a> uses Residual-VQ to construct the RQ-VAE, for generating high resolution images with more compressed codes.
Expand All @@ -77,8 +76,8 @@ residual_vq = ResidualVQ(

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = residual_vq(x)
print(quantized.shape, indices.shape, commit_loss.shape)
#> torch.Size([1, 1024, 256]) torch.Size([1, 1024, 8]) torch.Size([1, 8])

# (1, 1024, 256), (1, 1024, 8), (1, 8)
```

<a href="https://arxiv.org/abs/2305.02765">A recent paper</a> further proposes to do residual VQ on groups of the feature dimension, showing equivalent results to Encodec while using far fewer codebooks. You can use it by importing `GroupedResidualVQ`
Expand All @@ -97,9 +96,8 @@ residual_vq = GroupedResidualVQ(
x = torch.randn(1, 1024, 256)

quantized, indices, commit_loss = residual_vq(x)
print(quantized.shape, indices.shape, commit_loss.shape)
#> torch.Size([1, 1024, 256]) torch.Size([2, 1, 1024, 8]) torch.Size([2, 1, 8])

# (1, 1024, 256), (2, 1, 1024, 8), (2, 1, 8)
```

## Initialization
Expand All @@ -120,8 +118,8 @@ residual_vq = ResidualVQ(

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = residual_vq(x)
print(quantized.shape, indices.shape, commit_loss.shape)
#> torch.Size([1, 1024, 256]) torch.Size([1, 1024, 4]) torch.Size([1, 4])

# (1, 1024, 256), (1, 1024, 4), (1, 4)
```

## Increasing codebook usage
Expand All @@ -144,8 +142,8 @@ vq = VectorQuantize(

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)
print(quantized.shape, indices.shape, commit_loss.shape)
#> torch.Size([1, 1024, 256]) torch.Size([1, 1024]) torch.Size([1])

# (1, 1024, 256), (1, 1024), (1,)
```

### Cosine similarity
Expand All @@ -164,8 +162,8 @@ vq = VectorQuantize(

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)
print(quantized.shape, indices.shape, commit_loss.shape)
#> torch.Size([1, 1024, 256]) torch.Size([1, 1024]) torch.Size([1])

# (1, 1024, 256), (1, 1024), (1,)
```

### Expiring stale codes
Expand All @@ -184,8 +182,8 @@ vq = VectorQuantize(

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)
print(quantized.shape, indices.shape, commit_loss.shape)
#> torch.Size([1, 1024, 256]) torch.Size([1, 1024]) torch.Size([1])

# (1, 1024, 256), (1, 1024), (1,)
```

### Orthogonal regularization loss
Expand All @@ -209,9 +207,8 @@ vq = VectorQuantize(

img_fmap = torch.randn(1, 256, 32, 32)
quantized, indices, loss = vq(img_fmap) # (1, 256, 32, 32), (1, 32, 32), (1,)

# loss now contains the orthogonal regularization loss with the weight as assigned
print(quantized.shape, indices.shape, loss.shape)
#> torch.Size([1, 256, 32, 32]) torch.Size([1, 32, 32]) torch.Size([1])
```

### Multi-headed VQ
Expand All @@ -235,8 +232,8 @@ vq = VectorQuantize(

img_fmap = torch.randn(1, 256, 32, 32)
quantized, indices, loss = vq(img_fmap)
print(quantized.shape, indices.shape, loss.shape)
#> torch.Size([1, 256, 32, 32]) torch.Size([1, 32, 32, 8]) torch.Size([1])

# (1, 256, 32, 32), (1, 32, 32, 8), (1,)

```

Expand All @@ -259,8 +256,8 @@ quantizer = RandomProjectionQuantizer(

x = torch.randn(1, 1024, 512)
indices = quantizer(x)
print(indices.shape)
#> torch.Size([1, 1024, 16])

# (1, 1024, 16)
```

This repository should also automatically synchronizing the codebooks in a multi-process setting. If somehow it isn't, please open an issue. You can override whether to synchronize codebooks or not by setting `sync_codebook = True | False`
Expand All @@ -285,16 +282,14 @@ Thanks goes out to [@sekstini](https://github.com/sekstini) for porting over thi
import torch
from vector_quantize_pytorch import FSQ

levels = [8,5,5,5] # see 4.1 and A.4.1 in the paper
quantizer = FSQ(levels)
quantizer = FSQ(
levels = [8, 5, 5, 5]
)

x = torch.randn(1, 1024, 4) # 4 since there are 4 levels
xhat, indices = quantizer(x)

print(xhat.shape)
#> torch.Size([1, 1024, 4])
print(indices.shape)
#> torch.Size([1, 1024])
# (1, 1024, 4), (1, 1024)

assert torch.all(xhat == quantizer.indices_to_codes(indices))
```
Expand All @@ -318,12 +313,12 @@ x = torch.randn(1, 1024, 256)
residual_fsq.eval()

quantized, indices = residual_fsq(x)
print(quantized.shape, indices.shape)
#> torch.Size([1, 1024, 256]) torch.Size([1, 1024, 8])

# (1, 1024, 256), (1, 1024, 8)

quantized_out = residual_fsq.get_output_from_indices(indices)
print(quantized_out.shape)
#> torch.Size([1, 1024, 256])

# (1, 1024, 256)

assert torch.all(quantized == quantized_out)
```
Expand Down Expand Up @@ -357,8 +352,8 @@ quantizer = LFQ(
image_feats = torch.randn(1, 16, 32, 32)

quantized, indices, entropy_aux_loss = quantizer(image_feats, inv_temperature=100.) # you may want to experiment with temperature
print(quantized.shape, indices.shape, entropy_aux_loss.shape)
#> torch.Size([1, 16, 32, 32]) torch.Size([1, 32, 32]) torch.Size([])

# (1, 16, 32, 32), (1, 32, 32), ()

assert (quantized == quantizer.indices_to_codes(indices)).all()
```
Expand All @@ -379,13 +374,12 @@ quantizer = LFQ(
seq = torch.randn(1, 32, 16)
quantized, *_ = quantizer(seq)

# assert seq.shape == quantized.shape
assert seq.shape == quantized.shape

# video_feats = torch.randn(1, 16, 10, 32, 32)
# quantized, *_ = quantizer(video_feats)

# assert video_feats.shape == quantized.shape
video_feats = torch.randn(1, 16, 10, 32, 32)
quantized, *_ = quantizer(video_feats)

assert video_feats.shape == quantized.shape
```

Or support multiple codebooks
Expand All @@ -403,8 +397,8 @@ quantizer = LFQ(
image_feats = torch.randn(1, 16, 32, 32)

quantized, indices, entropy_aux_loss = quantizer(image_feats)
print(quantized.shape, indices.shape, entropy_aux_loss.shape)
#> torch.Size([1, 16, 32, 32]) torch.Size([1, 32, 32, 4]) torch.Size([])

# (1, 16, 32, 32), (1, 32, 32, 4), ()

assert image_feats.shape == quantized.shape
assert (quantized == quantizer.indices_to_codes(indices)).all()
Expand All @@ -427,12 +421,12 @@ x = torch.randn(1, 1024, 256)
residual_lfq.eval()

quantized, indices, commit_loss = residual_lfq(x)
print(quantized.shape, indices.shape, commit_loss.shape)
#> torch.Size([1, 1024, 256]) torch.Size([1, 1024, 8]) torch.Size([8])

# (1, 1024, 256), (1, 1024, 8), (8)

quantized_out = residual_lfq.get_output_from_indices(indices)
print(quantized_out.shape)
#> torch.Size([1, 1024, 256])

# (1, 1024, 256)

assert torch.all(quantized == quantized_out)
```
Expand Down Expand Up @@ -460,8 +454,8 @@ quantizer = LatentQuantize(
image_feats = torch.randn(1, 16, 32, 32)

quantized, indices, loss = quantizer(image_feats)
print(quantized.shape, indices.shape, loss.shape)
#> torch.Size([1, 16, 32, 32]) torch.Size([1, 32, 32]) torch.Size([])

# (1, 16, 32, 32), (1, 32, 32), ()

assert image_feats.shape == quantized.shape
assert (quantized == quantizer.indices_to_codes(indices)).all()
Expand All @@ -483,13 +477,13 @@ quantizer = LatentQuantize(

seq = torch.randn(1, 32, 16)
quantized, *_ = quantizer(seq)
print(quantized.shape)
#> torch.Size([1, 32, 16])

# (1, 32, 16)

video_feats = torch.randn(1, 16, 10, 32, 32)
quantized, *_ = quantizer(video_feats)
print(quantized.shape)
#> torch.Size([1, 16, 10, 32, 32])

# (1, 16, 10, 32, 32)

```

Expand All @@ -499,23 +493,22 @@ Or support multiple codebooks
import torch
from vector_quantize_pytorch import LatentQuantize

levels = [4, 8, 16]
dim = 9
num_codebooks = 3

model = LatentQuantize(levels, dim, num_codebooks=num_codebooks)
model = LatentQuantize(
levels = [4, 8, 16],
dim = 9,
num_codebooks = 3
)

input_tensor = torch.randn(2, 3, dim)
output_tensor, indices, loss = model(input_tensor)
print(output_tensor.shape, indices.shape, loss.shape)
#> torch.Size([2, 3, 9]) torch.Size([2, 3, 3]) torch.Size([])

# (2, 3, 9), (2, 3, 3), ()

assert output_tensor.shape == input_tensor.shape
assert indices.shape == (2, 3, num_codebooks)
assert loss.item() >= 0
```


## Citations

```bibtex
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ managed = true
dev-dependencies = [
"ruff>=0.4.2",
"pytest>=8.2.0",
"pytest-examples>=0.0.10",
"pytest-cov>=5.0.0",
]

Expand Down
27 changes: 0 additions & 27 deletions tests/test_examples_readme.py

This file was deleted.

Loading

8 comments on commit 34b9e97

@MisterBourbaki
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dear @lucidrains ,
What was the issue with the previous test? It was to make sure the examples in the README were up to date. Both test suites can coexist.
I am not sure this is a good approach to simply get rid of it?

I saw a comment on one of my commit regarding the tests, maybe it is related? I was planning to have a look.

@lucidrains
Copy link
Owner Author

@lucidrains lucidrains commented on 34b9e97 May 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@MisterBourbaki i didn't get rid of the tests, just duplicated the readme into a new test file

it is clearer that way for researchers

@MisterBourbaki
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understood what you did, but I think both should coexist.
What was there issue? If they do not want to contribute, and just want to modify the code to their liking, no tests will be run. If they do want to contribute, they should do it in a good way (but then, it is my humble opinion, and in the end your the boss here :) )

@MisterBourbaki
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean, with your test suite (which is fine by the way, and needed!) they should have the same issue that with the previous test suite...

@lucidrains
Copy link
Owner Author

@lucidrains lucidrains commented on 34b9e97 May 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@MisterBourbaki they wanted to modify something, tests broke, but had no idea where the tests were. it was too much trying to explain to them how the tests were in the readme, and how using print statements does some sort of implicit assert

@lucidrains
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@MisterBourbaki it is fine the way it is, if you see anything missing, please feel free to open a PR and just add it to the test file.

@MisterBourbaki
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I totally understant your point :) It is hard to find que equilibrium between code quality and ease of use. I will keep that in mind

@lucidrains
Copy link
Owner Author

@lucidrains lucidrains commented on 34b9e97 May 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@MisterBourbaki sounds good, thank you for getting it started!

and now that tests are in place, i expect to see a contribution or two of a new quantizer from a research paper 😄

Please sign in to comment.