Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions api/experimentation/migrations/0009_add_rollout_segment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Generated by Django 5.2.14 on 2026-06-19 09:59

import django.db.models.deletion
from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
("experimentation", "0008_experiment_results"),
("segments", "0030_add_default_to_segment_version"),
]

operations = [
migrations.AddField(
model_name="experiment",
name="rollout_segment",
field=models.OneToOneField(
blank=True,
null=True,
on_delete=django.db.models.deletion.SET_NULL,
related_name="experiment_rollout",
to="segments.segment",
),
),
]
7 changes: 7 additions & 0 deletions api/experimentation/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,13 @@ class Experiment(LifecycleModelMixin, SoftDeleteExportableModel): # type: ignor
updated_at = models.DateTimeField(auto_now=True)
started_at = models.DateTimeField(null=True, blank=True)
ended_at = models.DateTimeField(null=True, blank=True)
rollout_segment = models.OneToOneField(
"segments.Segment",
on_delete=models.SET_NULL,
related_name="experiment_rollout",
null=True,
blank=True,
)

class Meta:
constraints = [
Expand Down
66 changes: 66 additions & 0 deletions api/experimentation/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from django.db.models import QuerySet
from rest_framework import serializers

from core.dataclasses import AuthorData
from environments.models import Environment
from experimentation.dataclasses import WarehouseEventStats
from experimentation.metric_definitions import validate_metric_definition
Expand All @@ -18,14 +19,21 @@
WarehouseConnection,
WarehouseType,
)
from experimentation.services import create_experiment_rollout
from experimentation.types import (
SNOWFLAKE_DEFAULTS,
MetricExperimentResult,
SnowflakeConfig,
)
from features.feature_states.serializers import (
FeatureValueSerializer,
MultivariateValueSerializer,
validate_multivariate_state_values,
)
from features.feature_types import MULTIVARIATE
from features.models import Feature
from features.multivariate.serializers import NestedMultivariateFeatureOptionSerializer
from features.versioning.dataclasses import MultivariateValueChangeSet


class WarehouseConnectionSerializer(serializers.ModelSerializer): # type: ignore[type-arg]
Expand Down Expand Up @@ -207,6 +215,35 @@ class ExperimentMetricInlineSerializer(serializers.Serializer): # type: ignore[
expected_direction = serializers.ChoiceField(choices=ExpectedDirection.choices)


class ExperimentRolloutSerializer(serializers.Serializer): # type: ignore[type-arg]
enabled = serializers.BooleanField(required=True)
rollout_percentage = serializers.FloatField(
required=True, min_value=0, max_value=100
)
feature_state_value = FeatureValueSerializer(required=True)
multivariate_feature_state_values = MultivariateValueSerializer(
many=True, required=False
)

@staticmethod
def to_service_kwargs(data: dict[str, Any], request: Any) -> dict[str, Any]:
value = data["feature_state_value"]
return {
"enabled": data["enabled"],
"rollout_percentage": data["rollout_percentage"],
"feature_state_value": value["value"],
"value_type": value["type"],
"multivariate_values": [
MultivariateValueChangeSet(
multivariate_feature_option_id=mv["multivariate_feature_option"],
percentage_allocation=mv["percentage_allocation"],
)
for mv in data.get("multivariate_feature_state_values", [])
],
"author": AuthorData.from_request(request),
}


class ExperimentSerializer(serializers.ModelSerializer): # type: ignore[type-arg]
# Annotated with the common base type so ExperimentListSerializer can
# override the field with a read-only representation.
Expand All @@ -215,6 +252,7 @@ class ExperimentSerializer(serializers.ModelSerializer): # type: ignore[type-ar
required=False,
write_only=True,
)
experiment_rollout = ExperimentRolloutSerializer(required=False, write_only=True)

class Meta:
model = Experiment
Expand All @@ -225,6 +263,7 @@ class Meta:
"hypothesis",
"status",
"metrics",
"experiment_rollout",
"created_at",
"updated_at",
"started_at",
Expand Down Expand Up @@ -260,9 +299,28 @@ def validate(self, attrs: dict[str, Any]) -> dict[str, Any]:
raise serializers.ValidationError(
{"metrics": "Cannot change the metrics of an existing experiment."}
)
if self.instance is not None and "experiment_rollout" in attrs:
raise serializers.ValidationError(
{
"experiment_rollout": (
"Cannot change the rollout via this endpoint; "
"use the rollout endpoint instead."
)
}
)
self._validate_metrics(attrs.get("metrics") or [])
self._validate_rollout(attrs)
return attrs

def _validate_rollout(self, attrs: dict[str, Any]) -> None:
rollout = attrs.get("experiment_rollout")
feature = attrs.get("feature")
if not rollout or feature is None:
return
validate_multivariate_state_values(
feature, rollout.get("multivariate_feature_state_values", [])
)

def _validate_metrics(self, metrics: list[dict[str, Any]]) -> None:
metric_ids = [entry["metric"].id for entry in metrics]
if len(metric_ids) != len(set(metric_ids)):
Expand All @@ -272,6 +330,7 @@ def _validate_metrics(self, metrics: list[dict[str, Any]]) -> None:

def create(self, validated_data: dict[str, Any]) -> Experiment:
metrics: list[dict[str, Any]] = validated_data.pop("metrics", [])
rollout: dict[str, Any] | None = validated_data.pop("experiment_rollout", None)
with transaction.atomic():
experiment: Experiment = super().create(validated_data)
ExperimentMetric.objects.bulk_create(
Expand All @@ -282,6 +341,13 @@ def create(self, validated_data: dict[str, Any]) -> Experiment:
)
for entry in metrics
)
if rollout is not None:
create_experiment_rollout(
experiment,
**ExperimentRolloutSerializer.to_service_kwargs(
rollout, self.context["request"]
),
)
return experiment


Expand Down
87 changes: 87 additions & 0 deletions api/experimentation/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@
from django.conf import settings
from django.db.models import Q
from django.utils import timezone
from flag_engine.segments.constants import PERCENTAGE_SPLIT
Comment thread
gagantrivedi marked this conversation as resolved.
from rest_framework.exceptions import ValidationError

from audit.models import AuditLog
from audit.related_object_type import RelatedObjectType
from core.dataclasses import AuthorData
from experimentation.constants import (
CONTROL_VARIANT_KEY,
EXPERIMENT_FLAG,
Expand Down Expand Up @@ -50,14 +53,18 @@
srm_p_value,
)
from features.models import FeatureState
from features.versioning.dataclasses import FlagChangeSet, MultivariateValueChangeSet
from features.versioning.versioning_service import update_flag
from integrations.flagsmith.client import get_openfeature_client
from segments.models import Condition, Segment, SegmentRule

if typing.TYPE_CHECKING:
from collections.abc import Sequence
from datetime import datetime

from experimentation.models import Experiment, Metric, WarehouseConnection
from experimentation.types import ExposureGranularity
from features.feature_states.models import FeatureValueType
from organisations.models import Organisation
from users.models import FFAdminUser

Expand Down Expand Up @@ -512,6 +519,86 @@ def transition_experiment_status(
return experiment


def _create_rollout_segment(
experiment: Experiment, rollout_percentage: float
) -> Segment:
segment: Segment = Segment.objects.create(
name=f"experiment-{experiment.id}-rollout",
project=experiment.feature.project,
is_system_segment=True,
)
rule = SegmentRule.objects.create(segment=segment, type=SegmentRule.ALL_RULE)
Condition.objects.create(
rule=rule,
operator=PERCENTAGE_SPLIT,
property="$.identity.key",
value=str(rollout_percentage),
)
return segment


def create_experiment_rollout(
experiment: Experiment,
*,
enabled: bool,
rollout_percentage: float,
feature_state_value: str,
value_type: FeatureValueType,
multivariate_values: list[MultivariateValueChangeSet],
author: AuthorData,
) -> None:
segment = _create_rollout_segment(experiment, rollout_percentage)
experiment.rollout_segment = segment
experiment.save()
update_flag(
experiment.environment,
experiment.feature,
FlagChangeSet(
author=author,
enabled=enabled,
feature_state_value=feature_state_value,
type_=value_type,
segment_id=segment.id,
multivariate_values=multivariate_values,
),
)


def update_experiment_rollout(
experiment: Experiment,
*,
enabled: bool,
rollout_percentage: float,
feature_state_value: str,
value_type: FeatureValueType,
multivariate_values: list[MultivariateValueChangeSet],
author: AuthorData,
) -> None:
if experiment.status in (ExperimentStatus.RUNNING, ExperimentStatus.COMPLETED):
raise ValidationError(
f"Cannot update the rollout of a {experiment.status} experiment."
)
segment = experiment.rollout_segment
if segment is None:
raise ValidationError("Experiment has no rollout to update.")

condition = Condition.objects.get(rule__segment=segment, operator=PERCENTAGE_SPLIT)
condition.value = str(rollout_percentage)
condition.save()
update_flag(
experiment.environment,
experiment.feature,
FlagChangeSet(
author=author,
enabled=enabled,
feature_state_value=feature_state_value,
type_=value_type,
segment_id=segment.id,
multivariate_values=multivariate_values,
),
)
Comment thread
gagantrivedi marked this conversation as resolved.


def mark_warehouse_pending_connection(
connection: WarehouseConnection,
) -> WarehouseConnection:
Expand Down
22 changes: 21 additions & 1 deletion api/experimentation/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
ExperimentListSerializer,
ExperimentMetricSerializer,
ExperimentResultsSerializer,
ExperimentRolloutSerializer,
ExperimentSerializer,
MetricSerializer,
WarehouseConnectionSerializer,
Expand All @@ -54,11 +55,15 @@
mark_warehouse_pending_connection,
refresh_warehouse_connection_status,
transition_experiment_status,
update_experiment_rollout,
)
from experimentation.tasks import (
compute_experiment_exposures,
compute_experiment_results,
)
from features.feature_states.serializers import (
validate_multivariate_state_values,
)
from users.models import FFAdminUser

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -176,7 +181,7 @@ def get_serializer_context(self) -> dict[str, Any]:
return context

def get_serializer_class(self) -> type[BaseSerializer[Experiment]]:
if self.action in ("list", "retrieve", "start", "pause", "complete"):
if self.action in ("list", "retrieve", "start", "pause", "complete", "rollout"):
return ExperimentListSerializer
return ExperimentSerializer

Expand Down Expand Up @@ -290,6 +295,21 @@ def pause(self, request: Request, **kwargs: object) -> Response:
def complete(self, request: Request, **kwargs: object) -> Response:
return self._transition_status(ExperimentStatus.COMPLETED)

@action(detail=True, methods=["patch"])
def rollout(self, request: Request, **kwargs: object) -> Response:
experiment: Experiment = self.get_object()
serializer = ExperimentRolloutSerializer(data=request.data)
serializer.is_valid(raise_exception=True)
data = serializer.validated_data
validate_multivariate_state_values(
experiment.feature, data.get("multivariate_feature_state_values", [])
)
update_experiment_rollout(
experiment,
**ExperimentRolloutSerializer.to_service_kwargs(data, request),
)
return Response(self.get_serializer(experiment).data)

@action(detail=True, methods=["get"])
def exposures(self, request: Request, **kwargs: object) -> Response:
experiment: Experiment = self.get_object()
Expand Down
Loading
Loading