import time

for name in qlayers:
    logger.info(name)
    start = time.time()
    quantizers[name], scale, zero, g_idx = quantizers[name]
    # so far can only pack layer on CPU
    layer_device = qlayers[name].device
    qlayers[name].to("cpu")
    layers[name], scale, zero, g_idx = layers[name].to("cpu"), scale.to("cpu"), zero.to("cpu"), g_idx.to("cpu")
    qlayers[name].pack(layers[name], scale, zero, g_idx)
    qlayers[name].to(layer_device)
    print(f"Time for {name}: {time.time() - start}")
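As a side note, time.perf_counter() is a monotonic, high-resolution clock and is generally a better fit than time.time() for this kind of wall-clock measurement. A minimal sketch of the pattern, with a hypothetical work() standing in for the pack call:

```python
import time

def work():
    # hypothetical stand-in for qlayers[name].pack(...); any CPU-bound call works
    sum(i * i for i in range(100_000))

start = time.perf_counter()
work()
elapsed = time.perf_counter() - start
print(f"Time for work: {elapsed:.4f}s")
```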
This has timings:
Time for transformer.blocks.0.attn.Wqkv: 0.24
Time for transformer.blocks.0.attn.out_proj: 0.08
Time for transformer.blocks.0.ffn.down_proj: 0.25
Time for transformer.blocks.0.ffn.up_proj: 91.95
However, if I run it in parallel (which, as far as I can tell, preserves the behavior):
from concurrent.futures import ThreadPoolExecutor
import time

def pack_layer(name):
    logger.info(name)
    start = time.time()
    quantizers[name], scale, zero, g_idx = quantizers[name]
    layer_device = qlayers[name].device
    qlayers[name].to("cpu")
    layers[name], scale, zero, g_idx = layers[name].to("cpu"), scale.to("cpu"), zero.to("cpu"), g_idx.to("cpu")
    qlayers[name].pack(layers[name], scale, zero, g_idx)
    qlayers[name].to(layer_device)
    print(f"Time for {name}: {time.time() - start}")

with ThreadPoolExecutor() as executor:
    executor.map(pack_layer, qlayers.keys())
The timings are:
Time for transformer.blocks.0.attn.Wqkv: 0.21
Time for transformer.blocks.0.attn.out_proj: 0.10
Time for transformer.blocks.0.ffn.down_proj: 0.22
Time for transformer.blocks.0.ffn.up_proj: 3.64
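One caveat worth flagging with the parallel version above (a general concurrent.futures behavior, not anything optimum-specific): executor.map returns a lazy iterator, so an exception raised inside pack_layer only surfaces when that iterator is consumed, and silent failures are easy to miss. A minimal sketch of the safer pattern, using a toy work() function:

```python
from concurrent.futures import ThreadPoolExecutor

def work(name):
    # toy task that fails for one input, to show exception propagation
    if name == "bad":
        raise ValueError(f"failed on {name}")
    return f"packed {name}"

with ThreadPoolExecutor() as executor:
    try:
        # list(...) forces the iterator, so worker exceptions propagate here
        results = list(executor.map(work, ["a", "bad", "b"]))
    except ValueError as e:
        results = str(e)

print(results)  # "failed on bad"
```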
Expected behavior
What's strange is that packing takes so long at all: ~90 seconds just for the up projection. What's even stranger is that the individual per-layer packing times are lower when run in parallel, not just the overall wall time. That is, packing the up projection takes ~90 seconds sequentially but drops to roughly 3.6 seconds when the layers are packed in parallel.
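To check whether per-task time itself changes under threading, independent of optimum's packing, a self-contained harness along these lines can help. task() here is a hypothetical stand-in for pack_layer, not the real packing code:

```python
import time
from concurrent.futures import ThreadPoolExecutor

def task(name):
    # hypothetical stand-in for pack_layer: a small CPU-bound job, timed per task
    start = time.perf_counter()
    sum(i * i for i in range(200_000))
    return name, time.perf_counter() - start

names = ["Wqkv", "out_proj", "down_proj", "up_proj"]

# sequential per-task timings
seq = dict(map(task, names))

# parallel per-task timings
with ThreadPoolExecutor() as executor:
    par = dict(executor.map(task, names))

for name in names:
    print(f"{name}: sequential {seq[name]:.4f}s, parallel {par[name]:.4f}s")
```

For pure-Python CPU-bound tasks like this one, the GIL means parallel per-task times should be equal or worse, so per-task times dropping under threading, as reported above, would point at something inside the packing path itself rather than at the driver loop.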
For context: I'm quantizing opt-350m using GPTQ. The quantization itself is fast, but packing the layers is slow. The timings above come from instrumentation I added to GPTQ packing in optimum/optimum/gptq/quantizer.py (line 614 at commit 5c803db).