Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added a fix for mismatched logit and mask tensor shapes #236

Conversation

xaviruvpadhiyar98
Copy link
Contributor

Findings with pythia-70m-deduped
Vocab size is 50304 (mentioned in original config) in https://huggingface.co/EleutherAI/pythia-70m-deduped/blob/main/config.json
Vocab size is 50254 (when we run - tokenizer.vocab_size from AutoTokenizer)
Vocab size is 50277 (when we run - len(tokenizer.get_vocab()) from AutoTokenizer)
Logit size is 50304 (when we call self.model(...) )

Now for GPT2 Vocab size and Logit size both are same 50257

This fix would create mask with size of Logit rather than vocab

@@ -129,9 +129,10 @@ def create_proposal(
self.pstates = new_pstates

masks = []
mask_shape = (len(logits[0]),)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While this is going to work, I'd like to have an explanation as to why the vocab size does not correspond to the logit size for these models before merging. Do you have an idea?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well I can't say for sure on how vocabs are created in pythia models, may be due to how long these models are being trained or these vocab's are shared between different models who are trained for large set of tokens, I just found that we creating a mask for logits with vocab size and since vocab size's are different at each level incase of pythia/GTPNeoXTokenizer
I thought of creating a mask directly with logit size,

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

on the EleuterAI discord release channel "the embedding dim is rounded up to the nearest multiple of 128 * model-parallel-size for training efficiency" in response to this question. the higher token ids are there but don't mean anything. I think this means that @xaviruvpadhiyar98's approach is correct; rounding 50254 up to nearest multiple of 128 gives you 50304.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @maedoc for confirming..

@brandonwillard brandonwillard force-pushed the regex-masking-dimension-compatibility-#213 branch from 9b1d966 to 89162b6 Compare September 7, 2023 00:12
@brandonwillard brandonwillard linked an issue Sep 7, 2023 that may be closed by this pull request
@brandonwillard brandonwillard added bug text Linked to text generation structured generation Linked to structured generation labels Sep 7, 2023
@brandonwillard brandonwillard force-pushed the regex-masking-dimension-compatibility-#213 branch from 89162b6 to 4f82db0 Compare September 7, 2023 00:26
Copy link
Contributor

@brandonwillard brandonwillard left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rlouf, is there a good test for this that can be run in CI?

@rlouf
Copy link
Member

rlouf commented Sep 7, 2023

@xaviruvpadhiyar98 thank you for contributing. Could you add a test in test_integration_transformers.py that uses one of the models here and where the discrepancy between logit and vocab size happens?

@xaviruvpadhiyar98
Copy link
Contributor Author

xaviruvpadhiyar98 commented Sep 9, 2023

@rlouf I have created a test case utilizing "hf-internal-testing/tiny-random-GPTNeoXForQuestionAnswering" model
I have simply check shape of logits and shape of result
I believe there is already a test case called test_choice_proposal which checks for something similar and it passes

def test_transformers_logits_vocab_size():
    rng = torch.Generator()
    rng.manual_seed(0)

    model_name = "hf-internal-testing/tiny-random-GPTNeoXForQuestionAnswering"
    model = models.transformers(model_name, device="cpu")
    generator = generate.choice(model, ["True", "False"])

    logits = torch.ones(len(model.tokenizer.vocabulary))
    result = generator.create_proposal(torch.tensor([[]]), logits)

    # logits = tensor([1., 1., 1.,  ..., 1., 1., 1.]) 1024 torch.Size([1024])
    # result = tensor([[-inf, -inf, -inf,  ..., -inf, -inf, -inf]]) 1024 torch.Size([1, 1024])
    assert logits.shape[0] == result.shape[1]

Anyways ran the test cases and
== 130 passed, 4 xfailed in 39.63s ==
I had to use

pytest -vvv -W ignore::DeprecationWarning

coz it was throwing some - "pkg_resources is deprecated as an API."

I also had to update the conda yml file, which I have not commited

# To use:
#
#   $ conda env create -f environment.yml  # `mamba` works too for this command
#   $ conda activate outlines-dev
#
name: outlines-dev
channels:
  - conda-forge
  - huggingface
  - pytorch
  - nvidia
  - xformers
dependencies:
  - python==3.10.0
  - jinja2
  - numpy
  - pillow
  - pydantic
  - scipy
  - diffusers
  - pytest
  - pre-commit
  - transformers
  - pip
  - cuda
  - pytorch
  - torchvision
  - torchaudio
  - xformers
  # - pip:
  #   - -e ".[test]"

I had issues with creating an environment with the default packages, I have added following package and installed everything with micromamba and it works flawlessly.

micromamba env create -f environment.yml
pytorch/noarch (check zst)                          Checked  0.4s
nvidia/linux-64 (check zst)                         Checked  0.4s
pytorch/linux-64 (check zst)                        Checked  0.4s
huggingface/linux-64 (check zst)                    Checked  0.4s
nvidia/noarch (check zst)                           Checked  0.4s
huggingface/noarch (check zst)                      Checked  1.4s
xformers/noarch (check zst)                         Checked  1.4s
xformers/linux-64 (check zst)                       Checked  1.4s
pytorch/linux-64                                   169.3kB @   2.4MB/s  0.1s
nvidia/linux-64                                    142.3kB @   1.2MB/s  0.0s
huggingface/noarch                                  12.6kB @  28.2kB/s  0.4s
xformers/linux-64                                    4.3kB @   4.9kB/s  0.4s
pytorch/noarch                                      10.4kB @  11.3kB/s  0.0s
nvidia/noarch                                        4.9kB @   5.1kB/s  0.0s
xformers/noarch                                    115.0 B @  75.0 B/s  1.4s
huggingface/linux-64                                 5.6kB @   3.6kB/s  1.5s
conda-forge/noarch                                  12.3MB @   3.5MB/s  3.6s
conda-forge/linux-64                                30.1MB @   5.0MB/s  6.2s

Transaction

  Prefix: /home/dhruv/micromamba/envs/outlines-dev

  Updating specs:

   - python==3.10.0
   - jinja2
   - numpy
   - pillow
   - pydantic
   - scipy
   - diffusers
   - pytest
   - pre-commit
   - transformers
   - pip
   - cuda
   - pytorch
   - torchvision
   - torchaudio
   - xformers


  Package                         Version  Build                     Channel          Size
────────────────────────────────────────────────────────────────────────────────────────────
  Install:
────────────────────────────────────────────────────────────────────────────────────────────

  + cuda-demo-suite              12.2.140  0                         nvidia            5MB
  + cuda-documentation           12.2.140  0                         nvidia           91kB
  + cuda-nvml-dev                12.2.140  0                         nvidia           93kB
  + cuda-cudart                  12.2.140  0                         nvidia          197kB
  + cuda-nvrtc                   12.2.140  0                         nvidia           21MB
  + cuda-opencl                  12.2.140  0                         nvidia           11kB
  + libcublas                    12.2.5.6  0                         nvidia          372MB
  + libcufft                   11.0.8.103  0                         nvidia           87MB
  + libcufile                    1.7.2.10  0                         nvidia            1MB
  + libcurand                  10.3.3.141  0                         nvidia           54MB
  + libcusolver                11.5.2.141  0                         nvidia          117MB
  + libcusparse                12.1.2.141  0                         nvidia          177MB
  + libnpp                       12.2.1.4  0                         nvidia          146MB
  + libnvjitlink                 12.2.140  0                         nvidia           18MB
  + libnvjpeg                    12.2.2.4  0                         nvidia            3MB
  + cuda-cuobjdump               12.2.140  0                         nvidia          246kB
  + cuda-cuxxfilt                12.2.140  0                         nvidia          292kB
  + cuda-nvcc                    12.2.140  0                         nvidia           59MB
  + cuda-nvprune                 12.2.140  0                         nvidia           67kB
  + cuda-cccl                    12.2.140  0                         nvidia            1MB
  + cuda-driver-dev              12.2.140  0                         nvidia           17kB
  + cuda-profiler-api            12.2.140  0                         nvidia           19kB
  + cuda-cupti                   12.2.142  0                         nvidia           16MB
  + cuda-gdb                     12.2.140  0                         nvidia            6MB
  + cuda-nvdisasm                12.2.140  0                         nvidia           50MB
  + cuda-nvprof                  12.2.142  0                         nvidia            5MB
  + cuda-nvtx                    12.2.140  0                         nvidia           58kB
  + cuda-sanitizer-api           12.2.140  0                         nvidia           18MB
  + cuda-nsight                  12.2.144  0                         nvidia          119MB
  + nsight-compute             2023.2.2.3  0                         nvidia          831MB
  + cuda-nvrtc-dev               12.2.140  0                         nvidia           12kB
  + cuda-opencl-dev              12.2.140  0                         nvidia           59kB
  + libcublas-dev                12.2.5.6  0                         nvidia           76kB
  + libcufft-dev               11.0.8.103  0                         nvidia           14kB
  + gds-tools                    1.7.2.10  0                         nvidia           43MB
  + libcufile-dev                1.7.2.10  0                         nvidia           14kB
  + libcurand-dev              10.3.3.141  0                         nvidia          460kB
  + libcusolver-dev            11.5.2.141  0                         nvidia           50kB
  + libcusparse-dev            12.1.2.141  0                         nvidia          177MB
  + libnpp-dev                   12.2.1.4  0                         nvidia          550kB
  + libnvjitlink-dev             12.2.140  0                         nvidia           15MB
  + libnvjpeg-dev                12.2.2.4  0                         nvidia           13kB
  + cuda-libraries                 12.2.2  0                         nvidia            2kB
  + cuda-compiler                  12.2.2  0                         nvidia            1kB
  + cuda-cudart-dev              12.2.140  0                         nvidia          393kB
  + cuda-cupti-static            12.2.142  0                         nvidia           12MB
  + cuda-nvvp                    12.2.142  0                         nvidia          120MB
  + cuda-command-line-tools        12.2.2  0                         nvidia            1kB
  + cuda-nsight-compute            12.2.2  0                         nvidia            1kB
  + cuda-nvrtc-static            12.2.140  0                         nvidia           18MB
  + libcublas-static             12.2.5.6  0                         nvidia          413MB
  + libcufft-static            11.0.8.103  0                         nvidia          174MB
  + libcufile-static             1.7.2.10  0                         nvidia            4MB
  + libcurand-static           10.3.3.141  0                         nvidia           55MB
  + libcusolver-static         11.5.2.141  0                         nvidia           76MB
  + libcusparse-static         12.1.2.141  0                         nvidia          183MB
  + libnpp-static                12.2.1.4  0                         nvidia          143MB
  + libnvjpeg-static             12.2.2.4  0                         nvidia            3MB
  + cuda-runtime                   12.2.2  0                         nvidia            1kB
  + cuda-cudart-static           12.2.140  0                         nvidia            1MB
  + cuda-libraries-dev             12.2.2  0                         nvidia            2kB
  + cuda-libraries-static          12.2.2  0                         nvidia            2kB
  + cuda-visual-tools              12.2.2  0                         nvidia            2kB
  + cuda-tools                     12.2.2  0                         nvidia            1kB
  + cuda-toolkit                   12.2.2  0                         nvidia            1kB
  + cuda                           12.2.2  0                         nvidia            1kB
  + pytorch-mutex                     1.0  cpu                       pytorch           3kB
  + _libgcc_mutex                     0.1  conda_forge               conda-forge       3kB
  + libstdcxx-ng                   13.1.0  hfd8a6a1_0                conda-forge       4MB
  + python_abi                       3.10  3_cp310                   conda-forge       6kB
  + libgfortran5                   13.1.0  h15d22d2_0                conda-forge       1MB
  + ca-certificates             2023.7.22  hbcca054_0                conda-forge     150kB
  + ld_impl_linux-64                 2.40  h41732ed_0                conda-forge     705kB
  + libgcc-ng                      13.1.0  he5830b7_0                conda-forge     776kB
  + libzlib                        1.2.13  hd590300_5                conda-forge      62kB
  + zstd                            1.5.5  hfc55251_0                conda-forge     545kB
  + llvm-openmp                    16.0.6  h4dfa4b3_0                conda-forge      42MB
  + _openmp_mutex                     4.5  2_kmp_llvm                conda-forge       6kB
  + libgfortran-ng                 13.1.0  h69a702a_0                conda-forge      23kB
  + rdma-core                        28.9  h59595ed_1                conda-forge       4MB
  + libnuma                        2.0.16  h0b41bf4_1                conda-forge      41kB
  + xorg-libxdmcp                   1.1.3  h7f98852_0                conda-forge      19kB
  + pthread-stubs                     0.4  h36c2ea0_1001             conda-forge       6kB
  + sleef                           3.5.1  h9b69904_2                conda-forge       2MB
  + libcrc32c                       1.1.2  h9c3ff4c_0                conda-forge      20kB
  + keyutils                        1.6.1  h166bdaf_0                conda-forge     118kB
  + icu                              73.2  h59595ed_0                conda-forge      12MB
  + libiconv                         1.17  h166bdaf_0                conda-forge       1MB
  + libev                            4.33  h516909a_1                conda-forge     106kB
  + gflags                          2.2.2  he1b5a44_1004             conda-forge     117kB
  + libsqlite                      3.43.0  h2797004_0                conda-forge     841kB
  + ncurses                           6.4  hcb278e6_0                conda-forge     881kB
  + libuuid                        2.38.1  h0b41bf4_0                conda-forge      34kB
  + libffi                          3.4.2  h7f98852_5                conda-forge      58kB
  + aws-c-common                    0.9.0  hd590300_0                conda-forge     198kB
  + lz4-c                           1.9.4  hcb278e6_0                conda-forge     143kB
  + libbrotlicommon                 1.0.9  h166bdaf_9                conda-forge      71kB
  + bzip2                           1.0.8  h7f98852_4                conda-forge     496kB
  + xorg-libxau                    1.0.11  hd590300_0                conda-forge      14kB
  + libdeflate                       1.18  h0b41bf4_0                conda-forge      65kB
  + xz                              5.2.6  h166bdaf_0                conda-forge     418kB
  + lerc                            4.0.0  h27087fc_0                conda-forge     282kB
  + zlib                           1.2.13  hd590300_5                conda-forge      93kB
  + c-ares                         1.19.1  hd590300_0                conda-forge     113kB
  + snappy                         1.1.10  h9fff704_0                conda-forge      39kB
  + re2                        2023.03.02  h8c504da_0                conda-forge     201kB
  + libabseil                  20230125.3  cxx17_h59595ed_0          conda-forge       1MB
  + libutf8proc                     2.8.0  h166bdaf_0                conda-forge     101kB
  + libprotobuf                   3.21.12  hfc55251_2                conda-forge       2MB
  + ninja                          1.11.1  h924138e_0                conda-forge       2MB
  + yaml                            0.2.5  h7f98852_2                conda-forge      89kB
  + openssl                         3.1.2  hd590300_0                conda-forge       3MB
  + cudatoolkit                    11.8.0  h4ba93d1_12               conda-forge     716MB
  + xxhash                          0.8.2  hd590300_0                conda-forge      98kB
  + libpng                         1.6.39  h753d276_0                conda-forge     283kB
  + libwebp-base                    1.3.1  hd590300_0                conda-forge     400kB
  + libjpeg-turbo                 2.1.5.1  h0b41bf4_0                conda-forge     491kB
  + tk                             8.6.12  h27826a3_0                conda-forge       3MB
  + libnsl                          2.0.0  h7f98852_0                conda-forge      31kB
  + libopenblas                    0.3.24  pthreads_h413a1c8_0       conda-forge       5MB
  + ucx                            1.14.1  h64cca9d_4                conda-forge      15MB
  + glog                            0.6.0  h6f12383_0                conda-forge     114kB
  + libedit                  3.1.20191231  he28a2e2_2                conda-forge     124kB
  + readline                          8.2  h8228510_1                conda-forge     281kB
  + aws-c-compression              0.2.17  h4d4d85c_2                conda-forge      19kB
  + aws-checksums                  0.1.17  h4d4d85c_1                conda-forge      50kB
  + aws-c-sdkutils                 0.1.12  h4d4d85c_1                conda-forge      53kB
  + libbrotlienc                    1.0.9  h166bdaf_9                conda-forge     265kB
  + libbrotlidec                    1.0.9  h166bdaf_9                conda-forge      33kB
  + libxcb                           1.15  h0b41bf4_0                conda-forge     384kB
  + libxml2                        2.11.5  h232c23b_1                conda-forge     706kB
  + orc                             1.9.0  h2f23424_1                conda-forge       1MB
  + libevent                       2.1.12  hf998b51_1                conda-forge     427kB
  + s2n                            1.3.49  h06160fa_0                conda-forge     371kB
  + libssh2                        1.11.0  h0841786_0                conda-forge     271kB
  + libnghttp2                     1.52.0  h61bc06f_0                conda-forge     622kB
  + aws-c-cal                       0.6.1  hc309b26_1                conda-forge      51kB
  + libgrpc                        1.54.3  hb20ce57_0                conda-forge       6MB
  + freetype                       2.12.1  hca18f0e_1                conda-forge     626kB
  + libtiff                         4.5.1  h8b53f26_1                conda-forge     417kB
  + libblas                         3.9.0  18_linux64_openblas       conda-forge      15kB
  + krb5                           1.21.2  h659d440_0                conda-forge       1MB
  + sqlite                         3.43.0  h2c6b66d_0                conda-forge     831kB
  + libhwloc                        2.9.2  default_h554bfaf_1009     conda-forge       3MB
  + libthrift                      0.18.1  h8fd135c_2                conda-forge       4MB
  + aws-c-io                      0.13.32  he9a53bd_1                conda-forge     155kB
  + openjpeg                        2.5.0  hfec8fc6_2                conda-forge     352kB
  + lcms2                            2.15  haa2dc70_1                conda-forge     242kB
  + libcblas                        3.9.0  18_linux64_openblas       conda-forge      14kB
  + liblapack                       3.9.0  18_linux64_openblas       conda-forge      14kB
  + libcurl                         8.2.1  hca28451_0                conda-forge     373kB
  + tbb                         2021.10.0  h00ab1b0_0                conda-forge     186kB
  + aws-c-http                     0.7.11  h00aa349_4                conda-forge     194kB
  + aws-c-event-stream              0.3.1  h2e3709c_4                conda-forge      54kB
  + magma                           2.6.2  hc72dce7_0                conda-forge     243MB
  + libgoogle-cloud                2.12.0  hac9eb74_1                conda-forge      46MB
  + mkl                          2022.2.1  h84fe81f_16997            conda-forge     165MB
  + aws-c-mqtt                      0.9.3  hb447be9_1                conda-forge     162kB
  + aws-c-auth                      0.7.3  h28f7589_1                conda-forge     102kB
  + aws-c-s3                       0.3.14  hf3aad02_1                conda-forge      87kB
  + aws-crt-cpp                    0.21.0  hb942446_5                conda-forge     320kB
  + aws-sdk-cpp                   1.10.57  h85b1a90_19               conda-forge       4MB
  + libarrow                       12.0.1  hb87d912_8_cpu            conda-forge      28MB
  + tzdata                          2023c  h71feb2d_0                conda-forge     118kB
  + cuda-version                     11.8  h70ddcb2_2                conda-forge      21kB
  + python                         3.10.0  h543edf9_3_cpython        conda-forge      31MB
  + nccl                         2.18.5.1  h6103f9b_1                conda-forge     117MB
  + cudnn                       8.8.0.121  h838ba91_3                conda-forge     479MB
  + wheel                          0.41.2  pyhd8ed1ab_0              conda-forge      57kB
  + setuptools                     68.1.2  pyhd8ed1ab_0              conda-forge     462kB
  + pip                            23.2.1  pyhd8ed1ab_0              conda-forge       1MB
  + pycparser                        2.21  pyhd8ed1ab_0              conda-forge     103kB
  + mypy_extensions                 1.0.0  pyha770c72_0              conda-forge      10kB
  + pysocks                         1.7.1  pyha2e5f31_6              conda-forge      19kB
  + attrs                          23.1.0  pyh71513ae_1              conda-forge      55kB
  + pytz                     2023.3.post1  pyhd8ed1ab_0              conda-forge     187kB
  + python-tzdata                  2023.3  pyhd8ed1ab_0              conda-forge     143kB
  + six                            1.16.0  pyh6c4a22f_0              conda-forge      14kB
  + joblib                          1.3.2  pyhd8ed1ab_0              conda-forge     221kB
  + click                           8.1.7  unix_pyh707e725_0         conda-forge      84kB
  + dill                            0.3.7  pyhd8ed1ab_0              conda-forge      88kB
  + fsspec                       2023.9.0  pyh1a96a4e_0              conda-forge     123kB
  + zipp                           3.16.2  pyhd8ed1ab_0              conda-forge      19kB
  + nodeenv                         1.8.0  pyhd8ed1ab_0              conda-forge      34kB
  + cfgv                            3.3.1  pyhd8ed1ab_0              conda-forge      11kB
  + distlib                         0.3.7  pyhd8ed1ab_0              conda-forge     274kB
  + typing_extensions               4.7.1  pyha770c72_0              conda-forge      36kB
  + tomli                           2.0.1  pyhd8ed1ab_0              conda-forge      16kB
  + iniconfig                       2.0.0  pyhd8ed1ab_0              conda-forge      11kB
  + exceptiongroup                  1.1.3  pyhd8ed1ab_0              conda-forge      19kB
  + pluggy                          1.3.0  pyhd8ed1ab_0              conda-forge      23kB
  + charset-normalizer              3.2.0  pyhd8ed1ab_0              conda-forge      46kB
  + idna                              3.4  pyhd8ed1ab_0              conda-forge      57kB
  + certifi                     2023.7.22  pyhd8ed1ab_0              conda-forge     154kB
  + colorama                        0.4.6  pyhd8ed1ab_0              conda-forge      25kB
  + packaging                        23.1  pyhd8ed1ab_0              conda-forge      46kB
  + filelock                       3.12.3  pyhd8ed1ab_0              conda-forge      15kB
  + dataclasses                       0.8  pyhc8e2a94_3              conda-forge      10kB
  + python-dateutil                 2.8.2  pyhd8ed1ab_0              conda-forge     246kB
  + importlib-metadata              6.8.0  pyha770c72_0              conda-forge      26kB
  + typing_inspect                  0.9.0  pyhd8ed1ab_0              conda-forge      15kB
  + typing-extensions               4.7.1  hd8ed1ab_0                conda-forge      10kB
  + tqdm                           4.66.1  pyhd8ed1ab_0              conda-forge      89kB
  + pytest                          7.4.2  pyhd8ed1ab_0              conda-forge     245kB
  + importlib_metadata              6.8.0  hd8ed1ab_0                conda-forge       9kB
  + platformdirs                   3.10.0  pyhd8ed1ab_0              conda-forge      20kB
  + async-timeout                   4.0.3  pyhd8ed1ab_0              conda-forge      11kB
  + annotated-types                 0.5.0  pyhd8ed1ab_0              conda-forge      16kB
  + pyre-extensions                0.0.29  pyhd8ed1ab_0              conda-forge      15kB
  + virtualenv                    20.24.4  pyhd8ed1ab_0              conda-forge       3MB
  + brotli-python                   1.0.9  py310hd8f1fbe_9           conda-forge     326kB
  + multidict                       6.0.4  py310h1fa729e_0           conda-forge      53kB
  + frozenlist                      1.4.0  py310h2372a71_0           conda-forge      54kB
  + markupsafe                      2.1.3  py310h2372a71_0           conda-forge      24kB
  + python-xxhash                   3.3.0  py310h2372a71_0           conda-forge      23kB
  + safetensors                     0.3.3  py310hcb5633a_0           conda-forge       1MB
  + tokenizers                     0.13.3  py310he1f1126_0           conda-forge       4MB
  + regex                        2023.8.8  py310h2372a71_0           conda-forge     348kB
  + pyyaml                          6.0.1  py310h2372a71_0           conda-forge     171kB
  + pillow                         10.0.0  py310h582fbeb_0           conda-forge      47MB
  + numpy                          1.25.2  py310ha4c1d20_0           conda-forge       7MB
  + cffi                           1.15.1  py310h255011f_3           conda-forge     237kB
  + multiprocess                  0.70.15  py310h2372a71_0           conda-forge     243kB
  + pydantic-core                   2.6.3  py310hcb5633a_0           conda-forge       1MB
  + yarl                            1.9.2  py310h2372a71_0           conda-forge      97kB
  + pyarrow                        12.0.1  py310h0576679_8_cpu       conda-forge       4MB
  + pandas                          2.1.0  py310hcc13569_0           conda-forge      13MB
  + ukkonen                         1.0.1  py310hbf28c38_3           conda-forge      13kB
  + pytorch                        1.13.1  cuda112py310he33e0d6_200  conda-forge     380MB
  + xformers                       0.0.21  py310hd41b1e2_0           conda-forge     618kB
  + urllib3                         2.0.4  pyhd8ed1ab_0              conda-forge      98kB
  + aiosignal                       1.3.1  pyhd8ed1ab_0              conda-forge      13kB
  + jinja2                          3.1.2  pyhd8ed1ab_1              conda-forge     101kB
  + sacremoses                     0.0.53  pyhd8ed1ab_0              conda-forge     437kB
  + pydantic                        2.3.0  pyhd8ed1ab_0              conda-forge     258kB
  + identify                       2.5.27  pyhd8ed1ab_0              conda-forge      78kB
  + requests                       2.31.0  pyhd8ed1ab_0              conda-forge      57kB
  + pre-commit                      3.4.0  pyha770c72_0              conda-forge     180kB
  + pooch                           1.7.0  pyha770c72_3              conda-forge      51kB
  + huggingface_hub                0.16.4  pyhd8ed1ab_0              conda-forge     183kB
  + diffusers                      0.18.2  pyhd8ed1ab_0              conda-forge     373kB
  + torchaudio                     0.13.1  py310_cpu                 pytorch           7MB
  + aiohttp                         3.8.5  py310h2372a71_0           conda-forge     454kB
  + torchvision                    0.14.1  cuda112py310hb1d1f80_1    conda-forge       7MB
  + scipy                          1.11.2  py310ha4c1d20_0           conda-forge      16MB
  + datasets                       2.14.4  pyhd8ed1ab_0              conda-forge     347kB
  + transformers                   4.33.1  pyhd8ed1ab_0              conda-forge       3MB

  Summary:

  Install: 247 packages

  Total download: 6GB

────────────────────────────────────────────────────────────────────────────────────────────


Confirm changes: [Y/n] Y

@brandonwillard brandonwillard force-pushed the regex-masking-dimension-compatibility-#213 branch from e800c9e to d2051ef Compare September 9, 2023 17:51
@brandonwillard brandonwillard force-pushed the regex-masking-dimension-compatibility-#213 branch from d2051ef to 96b4c03 Compare September 12, 2023 00:10
Copy link
Contributor

@brandonwillard brandonwillard left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've rebased and updated the test so that it simulates the logits and vocab size mismatch.

@brandonwillard brandonwillard merged commit d46a792 into outlines-dev:main Sep 12, 2023
4 checks passed
@brandonwillard brandonwillard changed the title Added a fix for missed match tensor size #213 Added a fix for mismatched tensor size #213 Sep 17, 2023
@brandonwillard brandonwillard changed the title Added a fix for mismatched tensor size #213 Added a fix for mismatched logit and mask tensor shapes Sep 17, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug structured generation Linked to structured generation text Linked to text generation
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Regex masking dimension compatibility
4 participants