
Contrib: S3Diff one-step 4x super-resolution #149

Open
jimburtoft wants to merge 1 commit into aws-neuron:main from jimburtoft:contrib/s3diff

Conversation

@jimburtoft
Contributor

Summary

Adds S3Diff (ECCV 2024) one-step 4x super-resolution to contrib. S3Diff uses SD-Turbo's UNet with degradation-guided dynamic LoRA modulation: a DEResNet encoder estimates input degradation and produces per-layer [rank, rank] modulation matrices injected between LoRA A/B weights via einsum.

  • Uses torch_neuronx.trace() (model is ~2 GB, no tensor parallelism needed)
  • Five compiled components: DEResNet, CLIP text encoder, VAE encoder (with LoRA), UNet (with LoRA), VAE decoder
  • Validated on trn2.3xlarge, SDK 2.29
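The degradation-guided LoRA forward described above can be sketched in plain PyTorch. This is a minimal illustration, not the pipeline code from `src/modeling_s3diff.py`: the function name, shapes, and the identity modulation matrix are all assumptions made for the example; the real per-layer `[rank, rank]` matrices come from the DEResNet degradation estimate.

```python
import torch

def modulated_lora_forward(x, W, lora_A, lora_B, M, scale=1.0):
    """Sketch of a degradation-guided LoRA layer: a per-layer
    [rank, rank] modulation matrix M is injected between the
    LoRA A (down) and B (up) weights via einsum."""
    base = x @ W  # base projection: [batch, d_in] @ [d_in, d_out]
    # x @ A^T -> [batch, rank]; modulate with M; project up with B
    lora = torch.einsum("br,rs,ds->bd", x @ lora_A.T, M, lora_B)
    return base + scale * lora

d_in, d_out, rank = 64, 32, 8
x = torch.randn(2, d_in)
W = torch.randn(d_in, d_out)
lora_A = torch.randn(rank, d_in)   # LoRA down-projection
lora_B = torch.randn(d_out, rank)  # LoRA up-projection
M = torch.eye(rank)                # identity modulation == vanilla LoRA

y = modulated_lora_forward(x, W, lora_A, lora_B, M)
print(y.shape)  # torch.Size([2, 32])
```

With `M` set to the identity this reduces to an ordinary LoRA update, which makes the role of the modulation matrix easy to see: DEResNet rescales the LoRA path per layer based on the estimated input degradation.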

Benchmark Results (128x128 -> 512x512)

| Metric | Value |
| --- | --- |
| Warm generation time | 0.544 s |
| Throughput | ~1.8 img/s |
| Total compile time | ~21 min |
| CPU baseline | 11.53 s |
| Speedup vs CPU | ~21x |

Implementation Notes

  • LoRA components (VAE encoder, UNet, VAE decoder) use --model-type=unet-inference instead of --auto-cast=matmult because the small einsum modulation operations produce NaN under BF16 auto-casting.
  • DEResNet and CLIP text encoder (no LoRA) use --auto-cast=matmult normally.
  • Modulation MLPs are tiny and run on CPU.
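A hedged sketch of how the two compiler-arg sets might be wired up. The component names and the `compile_component` helper are illustrative, not the actual code in `src/modeling_s3diff.py`; only the flag strings themselves come from this PR.

```python
# Compiler args per component (illustrative grouping).
# LoRA-bearing components (VAE encoder, UNet, VAE decoder): the
# unet-inference model type avoids the NaNs seen when BF16 matmult
# auto-casting hits the small einsum modulation ops.
LORA_COMPILER_ARGS = ["--model-type=unet-inference"]

# LoRA-free components (DEResNet, CLIP text encoder) tolerate
# BF16 auto-casting of matmuls.
PLAIN_COMPILER_ARGS = ["--auto-cast=matmult"]

def compile_component(model, example_inputs, has_lora):
    # torch_neuronx is only available on Neuron instances (e.g. trn2),
    # so import it lazily to keep this sketch importable elsewhere.
    import torch_neuronx
    args = LORA_COMPILER_ARGS if has_lora else PLAIN_COMPILER_ARGS
    return torch_neuronx.trace(model, example_inputs, compiler_args=args)
```

Splitting the flag sets this way keeps the fast BF16 path for the components that can use it while protecting the numerically fragile LoRA modulation path.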

Files

  • src/modeling_s3diff.py — Full pipeline: DEResNet, LoRA forward, wrappers, S3DiffNeuronPipeline class
  • src/generate_s3diff.py — CLI generation script with weight download support
  • test/integration/test_model.py — 3 integration tests (smoke, SR size, timing)
  • README.md — Full documentation

Tests

All 3 integration tests pass on trn2.3xlarge:

  • test_smoke_pipeline_loads — PASSED
  • test_sr_produces_correct_size — PASSED (512x512 output, pixel std=19.9)
  • test_warm_generation_time — PASSED (0.544s < 2s threshold)
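The warm-timing check can be sketched as follows. This is a toy stand-in, assuming the pattern in `test/integration/test_model.py` (helper name and iteration counts are made up here); the real test times the compiled Neuron pipeline, not a dummy callable.

```python
import time

def measure_warm_time(run_once, warmup=1, iters=3):
    """Discard first-call overhead (compile/cache load), then
    time steady-state runs and return the mean seconds per run."""
    for _ in range(warmup):
        run_once()
    start = time.perf_counter()
    for _ in range(iters):
        run_once()
    return (time.perf_counter() - start) / iters

# Toy stand-in for the real pipeline call
warm = measure_warm_time(lambda: sum(range(10_000)))
assert warm < 2.0  # the PR's integration-test threshold
```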

Weights

  • SD-Turbo: stabilityai/sd-turbo
  • S3Diff LoRA: zhangap/S3Diff
  • DEResNet: from S3Diff GitHub repo

