@zhangtao0408 commented Dec 27, 2025

What does this PR do?

Moving pos_embed computation from CPU back to NPU results in a 1.07x speedup in Flux.1's end-to-end latency.

As of CANN 8.3.RC1, the previously slow torch.repeat_interleave operator has been optimized, so running pos_embed on the CPU is no longer faster. Results are shown below (the Device column is where pos_embed is computed):

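| Model         | Device | Resolution  | Steps | E2E latency (s) |
|---------------|--------|-------------|-------|-----------------|
| FLUX.1-DEV    | npu    | 1024 x 1024 | 50    | 25.54           |
| FLUX.1-DEV    | cpu    | 1024 x 1024 | 50    | 27.41           |
| FLUX.2-DEV    | npu    | 1024 x 1024 | 28    | 101.49          |
| FLUX.2-DEV    | cpu    | 1024 x 1024 | 28    | 118.22          |
| LongCat-Image | npu    | 768 x 1344  | 28    | 31.87           |
| LongCat-Image | cpu    | 768 x 1344  | 28    | 36.19           |
| Ovis-Image    | npu    | 1024 x 1024 | 28    | 27.16           |
| Ovis-Image    | cpu    | 1024 x 1024 | 28    | 40.47           |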
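
For reference, the per-model speedup can be derived from the latencies in the table as CPU latency divided by NPU latency. The snippet below is only a small arithmetic sketch; the `latency_s` dictionary and the `speedup` helper are illustrative names, not part of this PR.

```python
# Latencies in seconds, copied from the results table.
# Keys are (model, device on which pos_embed is computed).
latency_s = {
    ("FLUX.1-DEV", "npu"): 25.54,
    ("FLUX.1-DEV", "cpu"): 27.41,
    ("FLUX.2-DEV", "npu"): 101.49,
    ("FLUX.2-DEV", "cpu"): 118.22,
    ("LongCat-Image", "npu"): 31.87,
    ("LongCat-Image", "cpu"): 36.19,
    ("Ovis-Image", "npu"): 27.16,
    ("Ovis-Image", "cpu"): 40.47,
}


def speedup(model: str) -> float:
    """Return CPU-path latency divided by NPU-path latency for one model."""
    return latency_s[(model, "cpu")] / latency_s[(model, "npu")]


if __name__ == "__main__":
    for model in ("FLUX.1-DEV", "FLUX.2-DEV", "LongCat-Image", "Ovis-Image"):
        print(f"{model}: {speedup(model):.2f}x")
```

With these numbers the ratios come out to roughly 1.07x, 1.16x, 1.14x, and 1.49x respectively.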

Tested Hardware

Ascend 910B3

Repro Code

1. FLUX.1-dev

import time

import torch
import torch_npu
from torch_npu.contrib import transfer_to_npu

from diffusers import FluxPipeline
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("npu")

prompt = "A cat holding a sign that says hello world"

# Warmup
_ = pipe(prompt, height=1024, width=1024, guidance_scale=3.5, num_inference_steps=2, max_sequence_length=512, generator=torch.Generator("cpu").manual_seed(0))

# Inference
start_time = time.time()

image = pipe(prompt, height=1024, width=1024, guidance_scale=3.5, num_inference_steps=50, max_sequence_length=512, generator=torch.Generator("cpu").manual_seed(0)).images[0]
image.save("flux-dev.png")

end_time = time.time()
print(f"Time: {end_time - start_time:.2f}s")

2. FLUX.2-DEV

import time

import torch
import torch_npu
from torch_npu.contrib import transfer_to_npu

from diffusers import Flux2Pipeline
pipe = Flux2Pipeline.from_pretrained("black-forest-labs/FLUX.2-dev", torch_dtype=torch.bfloat16)
pipe.enable_group_offload(
    onload_device=torch.device("npu"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    use_stream=True
)

prompt = "A cat holding a sign that says hello world"

# Warmup
_ = pipe(prompt, height=1024, width=1024, guidance_scale=3.5, num_inference_steps=2, max_sequence_length=512, generator=torch.Generator("cpu").manual_seed(0))

# Inference
start_time = time.time()

image = pipe(prompt, height=1024, width=1024, guidance_scale=3.5, num_inference_steps=28, max_sequence_length=512, generator=torch.Generator("cpu").manual_seed(0)).images[0]
image.save("flux.2-dev.png")

end_time = time.time()
print(f"Time: {end_time - start_time:.2f}s")

3. LongCat-Image

import time

import torch
import torch_npu
from torch_npu.contrib import transfer_to_npu

from diffusers import LongCatImagePipeline
pipe = LongCatImagePipeline.from_pretrained("meituan-longcat/LongCat-Image", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

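# English gloss of the prompt (the original Chinese string is kept so the benchmark input is unchanged):
# "A young Asian woman wearing a yellow knit sweater with a white necklace. Her hands rest on her
# knees and her expression is serene. The background is a rough brick wall; warm afternoon sunlight
# falls on her, creating a calm, cozy atmosphere. A medium-distance shot highlights her expression
# and the details of her clothing. Soft light on her face emphasizes her features and the texture
# of her jewelry, adding depth and warmth. The composition is clean, and the brick texture
# complements the play of light and shadow, bringing out her elegance and composure."
prompt = '一个年轻的亚裔女性,身穿黄色针织衫,搭配白色项链。她的双手放在膝盖上,表情恬静。背景是一堵粗糙的砖墙,午后的阳光温暖地洒在她身上,营造出一种宁静而温馨的氛围。镜头采用中距离视角,突出她的神态和服饰的细节。光线柔和地打在她的脸上,强调她的五官和饰品的质感,增加画面的层次感与亲和力。整个画面构图简洁,砖墙的纹理与阳光的光影效果相得益彰,突显出人物的优雅与从容。'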

# WARMUP
image = pipe(prompt, height=768, width=1344, guidance_scale=4.0, num_inference_steps=2, num_images_per_prompt=1, generator=torch.Generator("cpu").manual_seed(43), enable_cfg_renorm=True, enable_prompt_rewrite=True).images[0]

# Inference
start_time = time.time()

image = pipe(prompt, height=768, width=1344, guidance_scale=4.0, num_inference_steps=28, num_images_per_prompt=1, generator=torch.Generator("cpu").manual_seed(43), enable_cfg_renorm=True, enable_prompt_rewrite=True).images[0]

image.save("longcat.png")

end_time = time.time()
print(f"Time: {end_time - start_time:.2f}s")

4. Ovis-Image

import time

import torch
import torch_npu
from torch_npu.contrib import transfer_to_npu

from diffusers import OvisImagePipeline
pipe = OvisImagePipeline.from_pretrained("AIDC-AI/Ovis-Image-7B", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

prompt = "A creative 3D artistic render where the text \"OVIS-IMAGE\" is written in a bold, expressive handwritten brush style using thick, wet oil paint. The paint is a mix of vibrant rainbow colors (red, blue, yellow) swirling together like toothpaste or impasto art. You can see the ridges of the brush bristles and the glossy, wet texture of the paint. The background is a clean artist's canvas. Dynamic lighting creates soft shadows behind the floating paint strokes. Colorful, expressive, tactile texture, 4k detail."

# Warmup
image = pipe(prompt, negative_prompt="", num_inference_steps=2, guidance_scale=5.0).images[0]

# Inference
start_time = time.time()
image = pipe(prompt, negative_prompt="", num_inference_steps=28, guidance_scale=5.0).images[0]
image.save("ovis.png")

end_time = time.time()
print(f"Time: {end_time - start_time:.2f}s")
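
The four scripts above share the same warmup-then-time pattern. As a rough sketch of how that pattern could be factored out, the helper below runs an optional warmup call and then times a single call. `time_call`, `warmup_fn`, and `sync` are hypothetical names introduced here for illustration; they are not part of diffusers or torch_npu. Passing a device synchronization function as `sync` is optional and only matters when the timed call may return before the device has finished its work.

```python
import time
from typing import Callable, Optional, Tuple, TypeVar

T = TypeVar("T")


def time_call(
    fn: Callable[[], T],
    warmup_fn: Optional[Callable[[], object]] = None,
    sync: Optional[Callable[[], None]] = None,
) -> Tuple[T, float]:
    """Run an optional warmup, then time one call to `fn`.

    `sync` (hypothetical hook) is called before starting and before stopping
    the timer so that pending asynchronous device work is not mis-attributed.
    Returns the result of `fn` and the elapsed wall-clock time in seconds.
    """
    if warmup_fn is not None:
        warmup_fn()  # e.g. a short 2-step pipeline run
    if sync is not None:
        sync()
    start = time.perf_counter()
    result = fn()
    if sync is not None:
        sync()
    elapsed = time.perf_counter() - start
    return result, elapsed


if __name__ == "__main__":
    # Stand-in workload; in the scripts above this would be the pipeline call.
    result, elapsed = time_call(
        fn=lambda: sum(range(1_000_000)),
        warmup_fn=lambda: sum(range(1_000)),
    )
    print(f"Time: {elapsed:.2f}s")
```

In the repro scripts, `fn` would wrap the full-length `pipe(...)` call and `warmup_fn` the 2-step call. Unlike the scripts above, this keeps `image.save(...)` outside the timed region.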

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@zhangtao0408
Author

@sayakpaul Please review this PR, thanks.

Member

@sayakpaul left a comment

Thanks! We should also update the other models that follow this pattern. For example:

if is_torch_npu_available():
    freqs_cos_image, freqs_sin_image = self.pos_embed(img_ids.cpu())
    image_rotary_emb = (freqs_cos_image.npu(), freqs_sin_image.npu())
    freqs_cos_text, freqs_sin_text = self.pos_embed(txt_ids.cpu())
    text_rotary_emb = (freqs_cos_text.npu(), freqs_sin_text.npu())
else:
    image_rotary_emb = self.pos_embed(img_ids)
    text_rotary_emb = self.pos_embed(txt_ids)

@zhangtao0408
Author

Thanks for the suggestion. I tested the FLUX.2-dev, LongCat-Image, and Ovis-Image models on the Ascend platform. All three got faster after moving the position embedding computation from the CPU back to the NPU.
