FSDP requires larger LR to converge at similar rate as DeepSpeed #2624

Closed
2 of 4 tasks
fabianlim opened this issue Apr 4, 2024 · 34 comments

Comments

@fabianlim
Contributor

fabianlim commented Apr 4, 2024

System Info

accelerate: 0.28
deepspeed: 0.14
torch: 2.2.1

FSDP configs

fsdp_config:
  
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch_policy: BACKWARD_PRE 
  fsdp_forward_prefetch: false
  fsdp_offload_params: false
  fsdp_sharding_strategy: 1 
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_cpu_ram_efficient_loading: true
  fsdp_sync_module_states: true 
mixed_precision: 'no'

machine_rank: 0 # rank of the machine where accelerate is launched
num_machines: 1
num_processes: 1  # default, override with --num_processes
rdzv_backend: static
same_network: true

Deepspeed config

{
    "train_micro_batch_size_per_gpu": "auto",
    "steps_per_print": 1,
    "zero_optimization": {
        "stage": 3,
        "offload_param": {"device": "none"},
        "offload_optimizer": {"device": "none"}
    },
    "bf16": {"enabled": true},
    "gradient_clipping": 1.0,
    "prescale_gradients": false,
    "wall_clock_breakdown": false
}

To reproduce use this script

# tested on
# - transformers==4.39.0.dev0
# - accelerate==0.28.0
# - deepspeed==0.14
# - torch==2.2.1
# - trl==nightly
# - fire==0.6.0

# to run:
#
# accelerate launch --config_file ./accelerate_config.yaml learning_rate_repro.py

from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, HfArgumentParser
from transformers import TrainingArguments
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
import torch

# to run on 4 gpus on FSDP (for DS change the config file): 
# 
# accelerate launch \
#     --num_processes 4 \
#     --main_process_port 29502 \
#     --config_file accelerate.yaml \
#     learning_rate_repro.py  \
#       --num_train_epochs 10 \
#       --output_dir './results' \
#       --per_device_train_batch_size 10 \
#       --lr_scheduler_type "linear" \
#       --learning_rate 1e-6 \
#       --logging_steps 1 \
#       --save_strategy 'no' \
#       --bf16

# to run throughput experiments with packing
# accelerate launch \
#     --num_processes 4 \
#     --main_process_port 29502 \
#     --config_file accelerate.yaml \
#     learning_rate_repro.py  \
#     --model_name ibm-granite/granite-7b-base \
#     --max_seq_length 4096 \
#     --num_train_epochs 1 \
#     --output_dir './results' \
#     --per_device_train_batch_size 8 \
#     --gradient_accumulation_steps 1 \
#     --include_tokens_per_second \
#     --gradient_checkpointing \
#     --lr_scheduler_type "linear" \
#     --learning_rate 1e-6 \
#     --logging_steps 1 \
#     --packing True \
#     --dataset_for_packing <path> \
#     --dataset_text_field ... \
#     --bf16 \
#     --max_steps 100 \
#     --save_strategy 'no'


def main(
    model_name: str = "mistralai/Mistral-7B-v0.1",
    max_seq_length=4096,
    num_data_samples=3000, # use a small number to clearly see the convergence
    load_model_dtype='bfloat16', # FSDP sharded params will take on the dtype the model is loaded in
    attn_implementation='sdpa',
    dataset_for_packing=None, # specify a path to a json for packing
    dataset_text_field=None, # specify the dataset text field for sft_trainer
):

    parser = HfArgumentParser(
        dataclass_types=TrainingArguments
    )
    training_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=getattr(torch, load_model_dtype), ## UPDATED
        attn_implementation=attn_implementation, ## UPDATED
    )

    # we set the max sequence length here
    tokenizer = AutoTokenizer.from_pretrained(
        model_name, model_max_length=max_seq_length,
    )

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    if not dataset_for_packing:
        # use the alpaca dataset
        dataset = load_dataset('tatsu-lab/alpaca', split='train')
        dataset = dataset.select(range(num_data_samples))

        def formatting_prompts_func(example):
            output_texts = []
            for i in range(len(example['instruction'])):
                text = f"### Question: {example['instruction'][i]}\n ### Answer: {example['output'][i]}"
                output_texts.append(text)
            return output_texts

        # taken from https://github.com/foundation-model-stack/fms-hf-tuning/blob/main/tuning/sft_trainer.py
        response_template_ids = tokenizer.encode(
            " ### Answer:", add_special_tokens=False
        )[2:]

        packing = False
        data_collator = DataCollatorForCompletionOnlyLM(
            response_template=response_template_ids,
            tokenizer=tokenizer, return_tensors='pt'
        )
        dataset_text_field = None
    else:
        # prepare a json dataset like this
        #  https://github.com/foundation-model-stack/fms-hf-tuning?tab=readme-ov-file#pre-process-the-jsonjsonl-dataset

        # for throughput
        packing = True
        data_collator = None
        formatting_prompts_func = None
        dataset_text_field = 'output'
        format_dataset = lambda example: { 
            f"{dataset_text_field}": example[f"{dataset_text_field}"]
            + tokenizer.eos_token
        }

        json_dataset = load_dataset("json", data_files=[dataset_for_packing])
        dataset = json_dataset["train"].map(format_dataset)

    trainer = SFTTrainer(
        model,
        args=training_args,
        train_dataset=dataset,
        formatting_func=formatting_prompts_func,
        dataset_text_field=dataset_text_field,
        max_seq_length=max_seq_length,
        data_collator=data_collator,
        packing=packing,
    )

    ## UPDATED
    # if trainer.is_fsdp_enabled:
    #     # while this can be set in accelerate, just make it explicit
    #     trainer.accelerator.state.fsdp_plugin.set_auto_wrap_policy(model)
    #     from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
    #     mp_policy = MixedPrecision(
    #         param_dtype=torch.bfloat16,
    #         buffer_dtype=torch.bfloat16,
    #         reduce_dtype=torch.bfloat16,
    #     )

    #     trainer.accelerator.state.fsdp_plugin.mixed_precision_policy = mp_policy

    # checkpoints will be saved every 100 steps as `pytorch.bin`
    trainer.train()

if __name__ == '__main__':
    import fire
    fire.Fire(main)

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • One of the scripts in the examples/ folder of Accelerate or an officially supported no_trainer script in the examples folder of the transformers repo (such as run_no_trainer_glue.py)
  • My own task or dataset (give details below)

Reproduction

  1. Run training with some model (e.g. mistralai/Mistral-7B-v0.1) on multiple GPUs (say 4) for both FSDP and DeepSpeed (DS). Set FSDP to FULL_SHARD and DS to ZeRO-3.
  2. Observe that the convergence for FSDP will be slower than DS, even though the learning rate is set exactly the same.

How slow is FSDP? Empirically, we found that if we increase the FSDP learning rate by a factor equal to the number of GPUs used, then the convergence is roughly the same as DS, as observed in the plots below.

image

This is unexpected: with a higher learning rate, we would expect the FSDP curve to converge faster than DS, not merely match it. Our initial guess was that DS perhaps sums the gradients instead of averaging them, but for ZeRO-3 we clearly see averaging being done if we trace into _avg_scatter_grads. In particular, the averaging happens inside reduce_scatter_coalesced.

Expected behavior

Expect that FSDP and DS, with all training parameters set the same, will converge at the same speed without having to scale the LRs.

@pacman100 @stas00 perhaps DS is doing an extra reduction somewhere that is not obvious?

@muellerzr
Collaborator

What's interesting here from my perspective is ages ago NVIDIA came out with a report that you should modify your LR by the number of GPUs, and there's been mixed results as to why that should be done. So on one hand I'm pleasantly surprised to see this behavior? https://huggingface.co/docs/accelerate/concept_guides/performance#learning-rates

@fabianlim
Contributor Author

> What's interesting here from my perspective is ages ago NVIDIA came out with a report that you should modify your LR by the number of GPUs, and there's been mixed results as to why that should be done. So on one hand I'm pleasantly surprised to see this behavior? https://huggingface.co/docs/accelerate/concept_guides/performance#learning-rates

@muellerzr thanks for sharing the above. But in this case we are not comparing a multi-GPU case with a single-GPU case, but rather two multi-GPU cases with the exact same number of GPUs.

Below is a debug run I'm doing where I only have 1 data point, to force convergence in a deterministic and fast manner. As you can see, the LRs are stepped exactly the same. The first loss and grad norm match, but even with the same LRs, the loss behaviors differ. I find this very perplexing.

DS

{'loss': 2.1219, 'grad_norm': 207.03898825364155, 'learning_rate': 9.998000000000002e-06, 'epoch': 0.02}
{'loss': 0.8048, 'grad_norm': 638.4725968208579, 'learning_rate': 9.996e-06, 'epoch': 0.04}
{'loss': 0.2986, 'grad_norm': 215.88360284394878, 'learning_rate': 9.994000000000001e-06, 'epoch': 0.06}

FSDP

{'loss': 2.1219, 'grad_norm': 207.0, 'learning_rate': 9.998000000000002e-06, 'epoch': 0.02}
{'loss': 0.8006, 'grad_norm': 636.0, 'learning_rate': 9.996e-06, 'epoch': 0.04}
{'loss': 0.2437, 'grad_norm': 362.0, 'learning_rate': 9.994000000000001e-06, 'epoch': 0.06}

@pacman100
Contributor

Hello @fabianlim, can you provide a minimal reproducer for the above? Are you using mixed-precision and gradient accumulation?

@fabianlim
Contributor Author

fabianlim commented Apr 5, 2024

@pacman100 I have updated the description to include the reproduction script. As you can see in the args, I do activate bf16 mixed precision, but no gradient accumulation.

Also, I have attached two runs here, for different LRs, 1e-5 and 1e-6. I now realize that FSDP's loss drops very slowly at a low LR, but once the LR is increased beyond some level, it appears to reach the same loss as DS.

image

Some observations

  • DS consumes more memory (using nvidia-smi, DS consumes close to 80GB but FSDP stays between 60GB ~ 70GB). Maybe my settings are not 1:1 equivalent in this comparison
  • if you notice the logs, somehow FSDP grad norms are quantized, but the DS grad norms are pretty much in high precision.

Somehow I feel all this points to a precision issue. Somehow at very low LRs, if the updates are too small, then some precision losses cause problems with the FSDP parameter updates. Or conversely, DS is holding something at higher precision (hence the more memory consumption) and is able to function at higher LRs.
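
The precision hypothesis above can be illustrated without any framework. The sketch below emulates bfloat16 rounding in pure Python (`to_bf16` and `to_fp32` are hypothetical helper names, not library functions, and NaN/inf are not handled). It checks two things: whether the grad norms from the logs above are exactly representable in bf16, and whether a weight of 1.0 moves at all when each update is only 1e-6, assuming a constant stand-in gradient of 1.0.

```python
import struct


def to_fp32(x: float) -> float:
    """Round a Python float (fp64) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16(x: float) -> float:
    """Round to the nearest bfloat16 value (round-to-nearest-even).

    Hypothetical helper for illustration only; NaN and inf are not handled.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # bf16 is the top 16 bits of the fp32 bit pattern, so round away the low 16.
    bias = 0x7FFF + ((bits >> 16) & 1)
    bits = ((bits + bias) >> 16) << 16
    return struct.unpack("<f", struct.pack("<I", bits))[0]


# Grad norms copied from the logs above.
fsdp_norms = [207.0, 636.0, 362.0]
ds_norms = [207.03898825364155, 638.4725968208579, 215.88360284394878]
fsdp_is_bf16 = [to_bf16(v) == v for v in fsdp_norms]
ds_is_bf16 = [to_bf16(v) == v for v in ds_norms]

# Assumed toy setting: weight starts at 1.0, gradient is always 1.0.
lr, grad, steps = 1e-6, 1.0, 1000
w_bf16 = w_fp32 = 1.0
for _ in range(steps):
    w_bf16 = to_bf16(w_bf16 - lr * grad)  # weight stored in bf16
    w_fp32 = to_fp32(w_fp32 - lr * grad)  # weight stored in fp32

print("FSDP norms representable in bf16:", fsdp_is_bf16)
print("DS norms representable in bf16:  ", ds_is_bf16)
print("weight after training, bf16 vs fp32:", w_bf16, w_fp32)
```

If the bf16 weight ends where it started while the fp32 weight has moved, that is consistent with updates falling below the bf16 precision limit at low learning rates.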

@stas00 any ideas from your side?

@lessw2020

@fabianlim - in the FSDP mixed precision policy you are running here, what is the reduce_dtype?
If it's torch.bfloat16, then does running with it at torch.float32 resolve this difference?

@lessw2020

in case this is helpful, this is what I mean - the reduce_dtype controls the precision of the communicated gradients:
Screenshot 2024-04-05 at 10 32 26 AM
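
For reference, a minimal sketch of the kind of policy being suggested, using PyTorch's `MixedPrecision` dataclass. The particular dtype combination is an assumption chosen for illustration: parameters and buffers are cast to bf16 for compute, while gradients are communicated in fp32.

```python
import torch
from torch.distributed.fsdp import MixedPrecision

# Assumed illustrative policy: bf16 compute, fp32 gradient reduction.
mp_policy = MixedPrecision(
    param_dtype=torch.bfloat16,   # dtype of params during forward/backward
    reduce_dtype=torch.float32,   # dtype of gradients during reduce-scatter
    buffer_dtype=torch.bfloat16,  # dtype of module buffers
)
print(mp_policy)
```

The commented-out block in the reproduction script sets a policy of the same type through `trainer.accelerator.state.fsdp_plugin.mixed_precision_policy`, with all three dtypes in bfloat16.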

@fabianlim
Contributor Author

fabianlim commented Apr 7, 2024

@lessw2020 Thanks for the suggestion! I realize now that:

  • FSDP flat_params will take whatever dtype the model is loaded in. Previously I had loaded the model in bfloat16. When I change to load the model in float32 the convergence improved.
  • On the other hand DS should be internally upcasting weights loaded in bfloat16 to float32, as the optimizer step updates in f32.

To summarize my understanding:

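| Method         | Load Model | Flat Params | Fwd/Bwd | Optim Params |
|----------------|------------|-------------|---------|--------------|
| FSDP           | bf16       | bf16        | bf16    | bf16         |
| FSDP (bf16-MP) | fp32       | fp32        | bf16    | fp32         |
| DS (bf16-MP)   | bf16       | fp32        | bf16    | fp32         |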

I updated the reproduction script to allow the dtype that the model is loaded in to be selected, and also to explicitly set the FSDP mp_policy to bfloat16. In the experiments below, the dotted curve corresponds to the second row of the table above. We can see that the convergence issues go away (at the cost of extra memory) when we load in float32.
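
The effect of the parameter dtype on the optimizer can be reproduced in miniature with a plain PyTorch optimizer, without FSDP or DS. This is only a sketch under simplifying assumptions (a single 4-element parameter, a constant stand-in gradient of 1.0, AdamW with default hyperparameters, and a hypothetical helper named `run`). It reports the dtype of the optimizer state and the parameter values after a few steps at a learning rate of 1e-6.

```python
import torch


def run(param_dtype: torch.dtype, steps: int = 100, lr: float = 1e-6):
    """Take a few AdamW steps on a toy parameter stored in `param_dtype`."""
    param = torch.nn.Parameter(torch.ones(4, dtype=param_dtype))
    opt = torch.optim.AdamW([param], lr=lr)
    for _ in range(steps):
        # Stand-in gradient; a real run would call loss.backward() instead.
        param.grad = torch.ones_like(param)
        opt.step()
    # The optimizer state is allocated to match the parameter it tracks.
    return param.detach(), opt.state[param]["exp_avg"].dtype


p_bf16, state_bf16 = run(torch.bfloat16)
p_fp32, state_fp32 = run(torch.float32)

print("bf16 param:", p_bf16, "| optimizer state dtype:", state_bf16)
print("fp32 param:", p_fp32, "| optimizer state dtype:", state_fp32)
```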

image

To conclude:

  • The convergence differences we experienced were not due to reduction comms or grad accumulation, but to model and optimizer precision.
  • I had wrongly assumed that DS keeps the optimizer precision the same as the model's, but some upcasting happens internally.

@raghukiran1224

@muellerzr is it worth writing a short blog post on this? I think it is very useful for folks who use both stacks.

@muellerzr
Collaborator

A blog certainly, a portion in the docs is even better, probably under a new concept guide

@fabianlim
Contributor Author

@muellerzr I have started to draft a concept guide here. Just putting the link here to solicit early feedback from you, but it's still very rough and I will be adding more to it https://github.com/fabianlim/accelerate/blob/update_docs_fsdp_deepspeed/docs/source/concept_guides/fsdp_and_deepspeed.md

While preparing it I tried to reference the other existing concept guides.

Currently it only has two sections

  1. Similarities / differences in how to set the accelerate configs, which I think is good for laying out the different flags and showing how they map to each other. This was one of the pain points I had coming to DS from FSDP: where and how to set what.
  2. Differences in how precisions are handled in the different frameworks. This section points to the discussion in this issue.

Any suggestions for more sections?

I tried to stay as high level as possible and not go into too much code detail, as this may change over time.

@minjiazhang

Hi @fabianlim, can you check the per-gpu-batch-size on each rank when using DS ZeRO3 and FSDP? The fact that ZeRO3 consumes more memory than FSDP makes me think perhaps the way these two frameworks decide the per-gpu-batch-size is different. You can also explicitly set the per-gpu-batch-size for DS by changing "auto" in "train_micro_batch_size_per_gpu": "auto" to a specific number (e.g., 16). You may also want to check how to set a fixed per-gpu-batch-size for FSDP.
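
As a sketch of the suggestion above, the snippet below pins the DS micro-batch size in the config posted in the issue description and computes the global batch size that both frameworks should then see. The values (4 processes, per-device batch size 10, no gradient accumulation) come from the first launch command in the description; the variable names are illustrative.

```python
import json

# DeepSpeed config from the issue description, as a Python dict.
ds_config = {
    "train_micro_batch_size_per_gpu": "auto",
    "steps_per_print": 1,
    "zero_optimization": {
        "stage": 3,
        "offload_param": {"device": "none"},
        "offload_optimizer": {"device": "none"},
    },
    "bf16": {"enabled": True},
    "gradient_clipping": 1.0,
    "prescale_gradients": False,
    "wall_clock_breakdown": False,
}

# Values taken from the first launch command in the description.
per_device_train_batch_size = 10
num_processes = 4
gradient_accumulation_steps = 1

# Replace "auto" with an explicit number so nothing is inferred at launch time.
ds_config["train_micro_batch_size_per_gpu"] = per_device_train_batch_size
global_batch_size = (
    per_device_train_batch_size * num_processes * gradient_accumulation_steps
)

print(json.dumps(ds_config, indent=4))
print("global batch size:", global_batch_size)
```

When the value is left as "auto", the HF Trainer integration is expected to fill it from `--per_device_train_batch_size`, so both variants should end up with the same per-GPU batch size.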

@fabianlim
Contributor Author

> Hi @fabianlim, can you check the per-gpu-batch-size on each rank when using DS ZeRO3 and FSDP? The fact that ZeRO3 consumes more memory than FSDP makes me think perhaps the way these two frameworks decide the per-gpu-batch-size is different. You can also explicitly set the per-gpu-batch-size for DS by changing "auto" in "train_micro_batch_size_per_gpu": "auto" to a specific number (e.g., 16). You may also want to check how to set a fixed per-gpu-batch-size for FSDP.

@minjiazhang they are the same, as you can see I set the train_micro_batch_size_per_gpu: auto exactly the way you said it #2624 (comment)

The reason for the extra memory is that in DS the local flat params, optimizer params, and gradients are kept in fp32 (see here), whereas in FSDP they follow the dtype of the model params.
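
A back-of-the-envelope estimate makes a gap of this kind plausible. The sketch assumes roughly 7e9 parameters, an Adam-style optimizer with two state tensors per parameter, and perfectly even sharding across 4 GPUs. It ignores activations, buffers, communication buffers, and any low-precision copies used for forward/backward, so it is not a prediction of the `nvidia-smi` numbers reported earlier. The function name is hypothetical.

```python
def sharded_state_gib(num_params, bytes_per_elem, num_gpus, optimizer_states=2):
    """Per-GPU GiB for params + grads + optimizer states held in one dtype."""
    per_param_bytes = bytes_per_elem * (1 + 1 + optimizer_states)
    return num_params * per_param_bytes / num_gpus / 2**30


NUM_PARAMS = 7e9  # assumed round figure for a 7B model
NUM_GPUS = 4

bf16_gib = sharded_state_gib(NUM_PARAMS, bytes_per_elem=2, num_gpus=NUM_GPUS)
fp32_gib = sharded_state_gib(NUM_PARAMS, bytes_per_elem=4, num_gpus=NUM_GPUS)

print(f"everything in bf16: {bf16_gib:.1f} GiB per GPU")
print(f"everything in fp32: {fp32_gib:.1f} GiB per GPU")
print(f"difference:         {fp32_gib - bf16_gib:.1f} GiB per GPU")
```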

@muellerzr
Collaborator

@fabianlim I think that alone is a great improvement to the docs already! Would you like to open the PR? :)

@fabianlim
Contributor Author

@muellerzr sure will do, give me a day or two to proof read it again!

@minjiazhang

Thank you, Fabian! That makes sense. Then I agree the casting issue of bf16 to the master fp32 weight is likely the one that caused the memory consumption difference and the convergence gap.

@stas00
Contributor

stas00 commented Apr 10, 2024

After reading this thread, basically the problem comes from FSDP keeping the training regime to the one of the initial model weights, regardless of the training regime specified by the user (w/ or w/o AMP).

Therefore if the model was loaded in half-precision - it should probably raise an exception if a user is trying to apply AMP to it.

I think this should be done on the code level and not documentation - nobody reads documentation. But this can be easily fixed in the code.

i.e.:

  • Allow 16bit model with direct bf16 training
  • Either don't allow 16bit model with AMP, or automatically upcast it to 32bit - probably the latter because these days most published models checkpoints are in 16bit.

Summary: Using AMP with 16bit model should be clearly an assertion
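
A minimal sketch of what such a check could look like. The function name `assert_fp32_for_amp` and the `allow_low_precision` override are hypothetical and not part of Accelerate; the sketch only inspects the trainable parameters of a plain `torch.nn.Module`.

```python
import torch


def assert_fp32_for_amp(model, mixed_precision, allow_low_precision=False):
    """Raise if AMP is requested while trainable params are not fp32.

    Hypothetical helper: `mixed_precision` follows the "no" / "bf16" / "fp16"
    strings used in the accelerate config shown earlier.
    """
    if mixed_precision in (None, "no") or allow_low_precision:
        return
    offending = [
        name
        for name, param in model.named_parameters()
        if param.requires_grad and param.dtype != torch.float32
    ]
    if offending:
        raise ValueError(
            f"mixed_precision={mixed_precision!r} expects fp32 trainable "
            f"params, but {len(offending)} are not, e.g. {offending[:3]}. "
            "Load the model in float32 or pass allow_low_precision=True."
        )


# A toy model loaded in bf16, standing in for a half-precision checkpoint.
model = torch.nn.Linear(4, 4).to(torch.bfloat16)
try:
    assert_fp32_for_amp(model, mixed_precision="bf16")
    raised = False
except ValueError as err:
    raised = True
    print(err)
```

Direct bf16 training (no AMP) passes the check, which matches the first bullet above.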

@fabianlim fabianlim reopened this Apr 11, 2024
@fabianlim
Contributor Author

fabianlim commented Apr 11, 2024

Reopening issue again since we may be adding more code.

@stas00 thanks for the suggestion. I took a stab at doing this after the FSDP wrap; I haven't tested it yet because I'm in the middle of something else. The idea is to look only at the FSDP flat_params, and upcast only those. The one thing I'm not sure of is whether it is possible to have a flat_param on the meta device when fsdp_cpu_ram_efficient_loading is activated. I need to check and test. FYI: @muellerzr

Maybe we can do both code updates and documentation; the docs also serve to explain other similarities / differences between FSDP and DS, besides this upcasting issue. I will raise a PR soon.

@stas00
Contributor

stas00 commented Apr 11, 2024

Conceptually your proposal looks good, @fabianlim - thank you for leading this important investigation!

Here is a bunch of suggestions/questions/concerns:

  • instead of doing one param at a time - won't it be simpler to do any on model params and if they are bf16 or fp16 then upcast the whole model to fp32 at once? or may be it should be all instead of any, but I think some models may have some params in fp32 (often buffers). It's possible that your cautious approach is the best after all.

  • what about buffers? These are critical too - especially when those are used in positional embeddings. Those too need to be upcast I think to get the correct behavior. Upcasting the whole model at once as suggested above will handle buffers as well I believe (but double check).

  • I'm worried that people will not notice that the model was upcast (warnings are usually either not read or missed because there are too many of them in most applications) and will wonder why so much memory is being used, especially if they have calculated and planned the memory usage.

  • Saving the model will now be unexpected as well - as a 16bit model will end up being saved as a 32bit model.

I know it'd be disruptive/BC-breaking but perhaps asserting and guiding the user to do the right thing would be a cleaner solution in the long term? This would work better for any custom loop where the user can correct their code. But solutions like HF Trainer would still require done-on-behalf-of-user approach, since the user can't control the dtype of the model.

Not sure what the best solution is and perhaps it might be a good idea to involve HF Transformers maintainers to make sure this solution is done across the board including HF Trainer.

@fabianlim
Contributor Author

fabianlim commented Apr 11, 2024

@stas00 thanks for the thoughts! these are very helpful!

  • Yes, I was trying to be cautious. I thought it best to follow how DeepSpeed was integrated into Trainer. I'm aware that DS only casts the flat params (but I need to check more carefully), so I tried to mimic that. Also, bearing in mind that FSDP has an ignored_modules control, doing a model.to will also affect the modules that were told to be ignored, which I feel is not the desired effect. The flat_params, on the other hand, only contain params that are not ignored.
  • I was on the fence about buffers. While the buffer precisions do affect the fwd/backward activations, I do not actually believe they affect the convergence. This is because I believe that the convergence issue was mostly due to the params being in half precision: when the LR is too low, the updates are below the precision limits. Still good to check; I will look closely at how DS handles buffers. DS might not be upcasting buffers, because it knows that these are not involved in the optimizer step.
  • Yes, it is true that most people do not pay attention to what was upcast. Even when I was comparing FSDP and DS (before I realized that DS had upcast the flat_params), I was wondering where the extra memory was coming from. So maybe we have to warn the user about the exact params that were upcast, kind of like the missing_keys warning.
  • Good point on saving the model! I need to check this. For full state dict, the model params are summoned to rank 0 when saved, so they should be saved in the reduce_type, but for sharded state dict, I'm not sure what the behavior is.

Yes I think I will try to propose something, and carefully document all the implications. Your thoughts are very helpful since you have pointed out certain things I might have overlooked. Then we should get the trainer maintainers to review.
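
The trade-offs discussed above (upcast only what the optimizer touches, leave ignored modules and buffers alone, tell the user exactly what changed) can be sketched on a plain `torch.nn.Module`. This is not the logic of the eventual PR, which works on FSDP flat parameters after wrapping; `upcast_trainable_params` and its arguments are hypothetical names.

```python
import warnings

import torch


def upcast_trainable_params(model, ignored_modules=()):
    """Upcast half-precision trainable params to fp32 and report their names.

    Hypothetical helper: params of `ignored_modules`, frozen params, and
    buffers are left in their original dtype.
    """
    ignored = {id(p) for m in ignored_modules for p in m.parameters()}
    upcast = []
    for name, param in model.named_parameters():
        if id(param) in ignored or not param.requires_grad:
            continue
        if param.dtype in (torch.float16, torch.bfloat16):
            param.data = param.data.to(torch.float32)
            upcast.append(name)
    if upcast:
        # Name the affected params, similar in spirit to a missing_keys warning.
        warnings.warn(f"Upcast {len(upcast)} params to float32: {upcast}")
    return upcast


# Toy model in bf16 with one buffer; the second layer is treated as ignored.
model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 2))
model.register_buffer("pos", torch.zeros(4))
model = model.to(torch.bfloat16)

upcast = upcast_trainable_params(model, ignored_modules=[model[1]])
print("upcast:", upcast)
print("ignored layer dtype:", model[1].weight.dtype)
print("buffer dtype:", model.pos.dtype)
```

Returning the list of names also leaves room for the caller to decide how to handle saving, since the upcast parameters would otherwise be written out in 32-bit.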

@stas00
Contributor

stas00 commented Apr 11, 2024

Your follow up notes resonate well, @fabianlim

The other thing I forgot to mention is to also think about how fp8 trainings will happen, as I'd imagine they would be AMP as well, but a different dtype. So an fp8 training is likely to use bf16/fp16 instead of fp32 for the high-precision. I'm aware that this mixed precision training isn't mainstream yet since not many get good results, but all the functionality is there, so it should be easy to consider that functionality right away.

@fabianlim
Contributor Author

@stas00 thanks for your suggestion! Got it, will consider fp8 training as well.

@muellerzr
Collaborator

muellerzr commented Apr 11, 2024

Btw we're working with the NVIDIA team right now to enable FSDP + fp8; it should be out by the next release (and hopefully merged within a week). See #2655

@fabianlim
Contributor Author

@muellerzr thanks for pointing it out, i see that the changes will be necessary to delay the autocast after the FSDP wrap. Can #2655 be tested at this point, just so I can get a sense of it?

@muellerzr
Collaborator

muellerzr commented Apr 12, 2024

Working on that soon @fabianlim :) Will update when it's working well enough to toy with on FSDP

@fabianlim
Contributor Author

@muellerzr btw the command line arguments here almost never get updated. It seems it is not integrated with Sphinx https://huggingface.co/docs/accelerate/en/package_reference/cli#accelerate-launch. I guess we should also manually update this reference?

@fabianlim
Contributor Author

I raised an initial PR. I have implemented and tested the casting logic requested by @stas00. The downside is that I had to use a few flags that are internal to the PyTorch FSDP wrapper.

I have also updated the documentation along with it.

@pacman100
Contributor

Hello @fabianlim and everyone part of this discussion,

Thank you @fabianlim for the updated reproducer code. We did mention in several issues that the base model has to be loaded in the default FP32 dtype with AMP for full finetuning when all the parameters are trainable; basically, all trainable params need to be in FP32 when using AMP. Please refer to this PR to enable FA2 with the model loaded in FP32 for proper convergence huggingface/transformers#28142 as well as the issue comment tagged therein huggingface/transformers#26498 (comment).

Thank you for the PR with the detailed concept guide ✨ and the logic to upcast the flat parameters to FP32.

@pacman100
Contributor

In terms of the Trainer being inline here, the changes from the above PR would be used by the Trainer and as such no changes on its side are required.

@fabianlim
Contributor Author

fabianlim commented Apr 16, 2024

@pacman100 thanks for your comments! Yes I was aware that it has been documented that when using AMP, one should load the model in fp32.

However, because nothing really stops the user from loading in low precision while using mixed precision, @stas00 suggested that, to avoid the difference in behavior between DS and FSDP in this (wrong-usage) scenario, we just upcast automatically and inform the user that we are doing this.

Thanks for pointing out huggingface/transformers#28142. I understand that this updates a warning to inform the user to load in float32 when using MP to train with FA2. Interestingly, I updated the repro script to simulate a user who ignores the warning (loads in bf16) but trains with MP. As you can see, both DS and the new upcasting fix for FSDP result in correct convergence.

This picture repeats the FSDP and DS experiments with FA2. I verified that the models were indeed loaded with FA2 and 16 bit weights.

image

I added an attn_implementation flag in the repro script

accelerate launch \
    --num_processes 4 \
    --main_process_port $MASTER_PORT \
    --config_file ./fixtures/accelerate_fsdp.yaml \
    learning_rate_repro.py  \
      --num_train_epochs 10 \
      --output_dir './results' \
      --per_device_train_batch_size 10 \
      --lr_scheduler_type "linear" \
      --learning_rate 1e-6 \
      --attn_implementation flash_attention_2 \
      --logging_steps 1 \
      --save_strategy 'no' \
      --bf16

@stas00
Contributor

stas00 commented Apr 29, 2024

@pacman100, relying on a user to remember to read some specific docs or even to be aware that they exist is futile.

Do you not want users to succeed and continue using Accelerate and not switch to another framework?

At the very least I think you'd want to assert (not warn) if AMP is used with non-FP32 weights and give a user a way to override the assert via a flag or an env var if they really want to waste more compute and lose time.

I fail to see how not preventing invalid use of Accelerate/FSDP is good to anybody. Perhaps I'm missing something?

This problem is going to affect more and more people as more and more models are distributed in half-precision.

@muellerzr
Collaborator

I can be open to allowing a flag, else raising a ValueError if not enabled. As otherwise yes this leads to terrible performance for models and many confusing behaviors (as mentioned here)

@stas00
Contributor

stas00 commented Apr 29, 2024

I commented on this before getting to the new PR #2674 so I'm not sure if my comment is still relevant - doesn't Accelerate now do the right thing and upcast on behalf of the user automatically?

and if so shouldn't this issue be closed then?

@muellerzr
Collaborator

Yes indeed you are right, that PR included upcasting so this can be closed. CC @fabianlim :)

@fabianlim
Contributor Author

@muellerzr @stas00 yes that is correct accelerate will now do the upcasting. Closing this issue!


7 participants