
[quantization] Save gpu memory#629

Merged
mhs4670go merged 1 commit into Samsung:main from stamalakhov:save_vram
Apr 15, 2026

Conversation

@stamalakhov
Contributor

@stamalakhov stamalakhov commented Apr 14, 2026

This PR restricts GPU memory use in the GPTQ algorithm to inference only, reducing overall GPU memory usage.

This makes it possible to calibrate with a large number of samples on a GPU with constrained memory.
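The general shape of such a change (a hypothetical sketch, not the actual diff; all names are invented) is to keep cached calibration data on the CPU and move each batch to the accelerator only for the forward pass. A stand-in `FakeTensor` class records its device so the sketch runs without torch:

```python
# Hypothetical sketch of the "GPU only for inference" pattern.
# FakeTensor stands in for torch.Tensor purely to illustrate data movement.
class FakeTensor:
    def __init__(self, data, device="cpu"):
        self.data = data
        self.device = device

    def to(self, device):
        return FakeTensor(self.data, device)

    def cpu(self):
        return self.to("cpu")


def calibrate(cached_batches, device="cuda"):
    """Move each cached batch to the accelerator only for the forward pass."""
    outputs = []
    for batch in cached_batches:
        gpu_batch = batch.to(device)                     # on-device just for inference
        result = FakeTensor(gpu_batch.data * 2, device)  # stand-in "forward pass"
        outputs.append(result.cpu())                     # offload the result right away
    return outputs


cache = [FakeTensor(i) for i in range(4)]  # calibration cache stays on the CPU
outs = calibrate(cache)
print([t.device for t in outs])  # every result is back on the CPU
```

The key property is that at most one batch lives on the accelerator at a time, so peak GPU memory no longer scales with the number of calibration samples.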

Sample run with 256 samples for TinyLlama/TinyLlama-1.1B-Chat-v1.0 on an 8 GB GPU:
```
Namespace(model='TinyLlama/TinyLlama-1.1B-Chat-v1.0', device='cuda', dtype='float32', seed=42, trust_remote_code=False, hf_token=None, no_tqdm=False, no_GPTQ=False, no_spinquant=True, no_PTQ=False, save_circle_to_folder=None, save_layers_to_folder=None, cache_dir='/mnt/storage/transformers_cache', nsamples_for_qcalibration=256, linear_weight_bits=4, gptq_mse='mse', max_seq_len=2048, calibrate_seq_len=2048, embedding_weight_bits=8, lm_head_weight_bits=4, eval_tasks=None, sensitivity_path=None)
=== Config ===
Model            : TinyLlama/TinyLlama-1.1B-Chat-v1.0
Device           : cuda
DType            : float32

Loading FP model …
Warning: You are sending unauthenticated requests to the HF Hub. Please set a HF_TOKEN to enable higher rate limits and faster downloads.
Loading weights: 100%|██████████| 201/201 [00:01<00:00, 109.01it/s]
Skipping SpinQuant preprocessing …

Calculating original perplexities …
Token indices sequence length is longer than the specified maximum sequence length for this model (341469 > 2048). Running this sequence through the model will result in indexing errors
PPL:  99%|█████████▉| 166/167 [02:26<00:00,  1.13it/s]

┌── Wikitext-2 test perplexity ─────────────
│ FP32 :     7.97
└───────────────────────────────────────────
Applying GPTQ …
Quantizing layers: 100%|██████████| 22/22 [16:18<00:00, 44.47s/layer]
Wrapping layers with PTQWrapper …
Calibrating PTQ observers …
  0%|          | 0/256 [00:00<?, ?it/s]`use_return_dict` is deprecated! Use `return_dict` instead!
100%|██████████| 256/256 [06:54<00:00,  1.62s/it]

Calculating perplexities …
PPL:  99%|█████████▉| 166/167 [04:01<00:01,  1.45s/it]

┌── Wikitext-2 test perplexity ─────────────
│ int16 :     8.66
└───────────────────────────────────────────
```
```
./ccex test --include-internal -k quantization.algorithm.test_gptq
RUN unit tests with -k quantization.algorithm.test_gptq ...
test_gptq_config_validate_rejects_non_positive_weight_bits_override (quantization.algorithm.test_gptq.GPTQTest) ... ok
test_gptq_config_validate_weight_bits_overrides (quantization.algorithm.test_gptq.GPTQTest) ... ok
test_groupwise_conv1d (quantization.algorithm.test_gptq.GPTQTest) ... ok
test_groupwise_conv2d (quantization.algorithm.test_gptq.GPTQTest) ... ok
test_model (quantization.algorithm.test_gptq.GPTQTest) ... Warning: You are sending unauthenticated requests to the HF Hub. Please set a HF_TOKEN to enable higher rate limits and faster downloads.
Loading weights: 100%|██████████| 75/75 [00:00<00:00, 7095.52it/s]
ok
test_net (quantization.algorithm.test_gptq.GPTQTest) ... No specialized wrapper found for ModuleList; applying recursive wrapping.
ok
test_net_on_zero_inputs (quantization.algorithm.test_gptq.GPTQTest) ... ok
test_normconv1d (quantization.algorithm.test_gptq.GPTQTest) ... ok
test_normconv1d_with_logits (quantization.algorithm.test_gptq.GPTQTest) ... ok
test_normconv2d (quantization.algorithm.test_gptq.GPTQTest) ... ok
test_normconv2d_on_zero_inputs (quantization.algorithm.test_gptq.GPTQTest) ... ok
test_normconv2d_with_logits (quantization.algorithm.test_gptq.GPTQTest) ... ok
test_normconv3d (quantization.algorithm.test_gptq.GPTQTest) ... ok
test_normconv3d_on_zero_inputs (quantization.algorithm.test_gptq.GPTQTest) ... ok
test_normconv3d_with_logits (quantization.algorithm.test_gptq.GPTQTest) ... ok
test_paddednormconv2d (quantization.algorithm.test_gptq.GPTQTest) ... ok
test_paddednormconv3d (quantization.algorithm.test_gptq.GPTQTest) ... ok
test_resolve_weight_bits_priority (quantization.algorithm.test_gptq.GPTQTest) ... ok
test_transposed_conv2d (quantization.algorithm.test_gptq.GPTQTest) ... ok
test_transposed_conv2d_with_logits (quantization.algorithm.test_gptq.GPTQTest) ... ok
test_weight_bits_overrides_are_applied_per_module (quantization.algorithm.test_gptq.GPTQTest) ... ok

----------------------------------------------------------------------
Ran 21 tests in 72.134s

OK
```

TICO-DCO-1.0-Signed-off-by: s.malakhov s.malakhov@partner.samsung.com

@stamalakhov stamalakhov self-assigned this Apr 14, 2026
@stamalakhov stamalakhov force-pushed the save_vram branch 2 times, most recently from 0130e7e to 2f97b87 Compare April 14, 2026 11:36
@stamalakhov stamalakhov requested a review from mhs4670go April 14, 2026 11:40
```python
    unit="batch",
    disable=not gptq_conf.show_progress,
):
    device = next(model.parameters()).device
```
Contributor
(minor) The device lookup is recomputed inside the inner loop. It could be moved outside the loop to avoid redundant work and improve readability.
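The suggestion is plain loop-invariant code motion. A hypothetical before/after (stand-in `Model` class, not the actual project code):

```python
# Illustration of the review comment: the device of a model's parameters does
# not change between iterations, so look it up once before the loop.
class Param:
    def __init__(self, device):
        self.device = device


class Model:
    def parameters(self):
        yield Param("cuda")


model = Model()

# before: recomputed on every iteration
for _ in range(3):
    device = next(model.parameters()).device

# after: hoisted out of the loop
device = next(model.parameters()).device
for _ in range(3):
    pass  # use `device` inside the loop body

print(device)
```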

Contributor Author

fixed

Comment on lines +126 to +132
```python
self.cache_kwargs[k].append(
    v.cpu()
    if isinstance(v, torch.Tensor)
    else (v[0].cpu(), v[1].cpu())
    if isinstance(v, tuple)
    else None
)
```
Contributor

We used to preserve kwargs as-is, but this patch now converts only tensors and 2-element tuples, and replaces everything else with None. That can change input semantics.

I think this should be implemented with a recursive “move tensors only” helper that preserves the original container structure.

```python
# let's add this function to tico/utils/utils.py
import torch


def move_to_device(obj, device):
    """
    Recursively move tensors inside a nested structure to the given device.
    Non-tensor objects are preserved as-is.
    """
    if isinstance(obj, torch.Tensor):
        return obj.to(device)
    elif isinstance(obj, tuple):
        return tuple(move_to_device(x, device) for x in obj)
    elif isinstance(obj, list):
        return [move_to_device(x, device) for x in obj]
    elif isinstance(obj, dict):
        return {k: move_to_device(v, device) for k, v in obj.items()}
    # preserve everything else (bool, int, None, custom objects, etc.)
    return obj


def move_to_cpu(obj):
    return move_to_device(obj, "cpu")
```

Then, we can just call it like this.

```python
# after
self.cache_kwargs[k].append(move_to_cpu(v))

cache_kwargs_batch = gather_single_batch_from_dict(self.cache_kwargs, batch_idx)
cache_kwargs_batch = move_to_device(cache_kwargs_batch, device)

cache_args_batch = gather_single_batch_from_list(self.cache_args, batch_idx)
cache_args_batch = move_to_device(cache_args_batch, device)
```
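The structure-preserving recursion this helper relies on can be exercised without torch by swapping the `torch.Tensor` check for a duck-typed stand-in (hypothetical test scaffolding, not part of the PR):

```python
# Mirror of the suggested move_to_device, written against a stand-in tensor
# type so the structure-preserving behaviour can be checked without torch.
class StubTensor:
    def __init__(self, device="cuda"):
        self.device = device

    def to(self, device):
        return StubTensor(device)


def move_to_device(obj, device):
    """Recursively move stand-in tensors; preserve all other objects as-is."""
    if isinstance(obj, StubTensor):
        return obj.to(device)
    if isinstance(obj, tuple):
        return tuple(move_to_device(x, device) for x in obj)
    if isinstance(obj, list):
        return [move_to_device(x, device) for x in obj]
    if isinstance(obj, dict):
        return {k: move_to_device(v, device) for k, v in obj.items()}
    return obj  # bool, int, None, custom objects, ...


nested = {"kv": (StubTensor(), StubTensor()), "flag": True, "ids": [1, 2]}
moved = move_to_device(nested, "cpu")
print(moved["kv"][0].device, moved["flag"], moved["ids"])  # cpu True [1, 2]
```

Unlike the tuple-only special case it replaces, this traversal keeps the original container types (tuple stays tuple, dict stays dict) and never drops non-tensor values.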

Contributor Author

Ok. Understood. Thank you!

Contributor Author

Fixed.

@stamalakhov stamalakhov force-pushed the save_vram branch 2 times, most recently from ce1303e to 1e68dda Compare April 15, 2026 06:42
@stamalakhov stamalakhov requested a review from mhs4670go April 15, 2026 06:50
```python
    return padding


def move_to_device(obj, device):
```
Contributor

How about moving this to tico/utils/utils.py? This function seems to be used in other places as well.

Contributor Author

done

This PR uses gpu memory in GPTQ algorithm only for inference to reduce gpu memory usage.

TICO-DCO-1.0-Signed-off-by: s.malakhov <s.malakhov@partner.samsung.com>
Contributor

@mhs4670go mhs4670go left a comment


LGTM

@mhs4670go mhs4670go merged commit ba2c5d1 into Samsung:main Apr 15, 2026
7 checks passed
@stamalakhov stamalakhov deleted the save_vram branch April 15, 2026 08:27