Skip to content

Commit 2eef5d0

Browse files
committed
feat: adapt res samplers for flow models for eta > 0
1 parent a564fdf commit 2eef5d0

1 file changed

Lines changed: 27 additions & 6 deletions

File tree

src/denoiser.hpp

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -808,6 +808,18 @@ static std::tuple<float, float, float> get_ancestral_step_flow(float sigma_from,
808808
return {sigma_down, sigma_up, alpha_scale};
809809
}
810810

811+
static std::tuple<float, float, float> get_ancestral_step(float sigma_from,
812+
float sigma_to,
813+
float eta,
814+
bool is_flow_denoiser) {
815+
if (is_flow_denoiser) {
816+
return get_ancestral_step_flow(sigma_from, sigma_to, eta);
817+
} else {
818+
auto [sigma_down, sigma_up] = get_ancestral_step(sigma_from, sigma_to, eta);
819+
return {sigma_down, sigma_up, 1.0f};
820+
}
821+
}
822+
811823
static sd::Tensor<float> sample_euler_ancestral(denoise_cb_t model,
812824
sd::Tensor<float> x,
813825
const std::vector<float>& sigmas,
@@ -1148,6 +1160,7 @@ static sd::Tensor<float> sample_res_multistep(denoise_cb_t model,
11481160
sd::Tensor<float> x,
11491161
const std::vector<float>& sigmas,
11501162
std::shared_ptr<RNG> rng,
1163+
bool is_flow_denoiser,
11511164
float eta) {
11521165
sd::Tensor<float> old_denoised = x;
11531166
bool have_old_sigma = false;
@@ -1179,7 +1192,8 @@ static sd::Tensor<float> sample_res_multistep(denoise_cb_t model,
11791192

11801193
float sigma_from = sigmas[i];
11811194
float sigma_to = sigmas[i + 1];
1182-
auto [sigma_down, sigma_up] = get_ancestral_step(sigma_from, sigma_to, eta);
1195+
1196+
auto [sigma_down, sigma_up, alpha_scale] = get_ancestral_step(sigma_from, sigma_to, eta, is_flow_denoiser);
11831197

11841198
if (sigma_down == 0.0f || !have_old_sigma) {
11851199
x += ((x - denoised) / sigma_from) * (sigma_down - sigma_from);
@@ -1206,7 +1220,10 @@ static sd::Tensor<float> sample_res_multistep(denoise_cb_t model,
12061220
x = sigma_fn(h) * x + h * (b1 * denoised + b2 * old_denoised);
12071221
}
12081222

1209-
if (sigmas[i + 1] > 0 && sigma_up > 0.0f) {
1223+
if (sigma_to > 0.0f && sigma_up > 0.0f) {
1224+
if (is_flow_denoiser) {
1225+
x *= alpha_scale;
1226+
}
12101227
x += sd::Tensor<float>::randn_like(x, rng) * sigma_up;
12111228
}
12121229

@@ -1221,6 +1238,7 @@ static sd::Tensor<float> sample_res_2s(denoise_cb_t model,
12211238
sd::Tensor<float> x,
12221239
const std::vector<float>& sigmas,
12231240
std::shared_ptr<RNG> rng,
1241+
bool is_flow_denoiser,
12241242
float eta) {
12251243
const float c2 = 0.5f;
12261244
auto t_fn = [](float sigma) -> float { return -logf(sigma); };
@@ -1249,7 +1267,7 @@ static sd::Tensor<float> sample_res_2s(denoise_cb_t model,
12491267
}
12501268
sd::Tensor<float> denoised = std::move(denoised_opt);
12511269

1252-
auto [sigma_down, sigma_up] = get_ancestral_step(sigma_from, sigma_to, eta);
1270+
auto [sigma_down, sigma_up, alpha_scale] = get_ancestral_step(sigma_from, sigma_to, eta, is_flow_denoiser);
12531271

12541272
sd::Tensor<float> x0 = x;
12551273
if (sigma_down == 0.0f || sigma_from == 0.0f) {
@@ -1278,7 +1296,10 @@ static sd::Tensor<float> sample_res_2s(denoise_cb_t model,
12781296
x = x0 + h * (b1 * eps1 + b2 * eps2);
12791297
}
12801298

1281-
if (sigmas[i + 1] > 0 && sigma_up > 0.0f) {
1299+
if (sigma_to > 0.0f && sigma_up > 0.0f) {
1300+
if (is_flow_denoiser) {
1301+
x *= alpha_scale;
1302+
}
12821303
x += sd::Tensor<float>::randn_like(x, rng) * sigma_up;
12831304
}
12841305
}
@@ -1577,9 +1598,9 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
15771598
case IPNDM_V_SAMPLE_METHOD:
15781599
return sample_ipndm_v(model, std::move(x), sigmas);
15791600
case RES_MULTISTEP_SAMPLE_METHOD:
1580-
return sample_res_multistep(model, std::move(x), sigmas, rng, eta);
1601+
return sample_res_multistep(model, std::move(x), sigmas, rng, is_flow_denoiser, eta);
15811602
case RES_2S_SAMPLE_METHOD:
1582-
return sample_res_2s(model, std::move(x), sigmas, rng, eta);
1603+
return sample_res_2s(model, std::move(x), sigmas, rng, is_flow_denoiser, eta);
15831604
case ER_SDE_SAMPLE_METHOD:
15841605
return sample_er_sde(model, std::move(x), sigmas, rng, is_flow_denoiser, eta);
15851606
case DDIM_TRAILING_SAMPLE_METHOD:

0 commit comments

Comments
 (0)