45 changes: 45 additions & 0 deletions cloud_pipelines_backend/api_router.py
@@ -107,6 +107,15 @@ def handle_permission_error(request: fastapi.Request, exc: errors.PermissionErro
content={"message": str(exc)},
)

@app.exception_handler(errors.ItemAlreadyExistsError)
def handle_item_already_exists_error(
request: fastapi.Request, exc: errors.ItemAlreadyExistsError
):
return fastapi.responses.JSONResponse(
status_code=409,
content={"message": str(exc)},
)

get_user_details_dependency = fastapi.Depends(user_details_getter)

def get_user_name(
@@ -390,6 +399,42 @@ def get_current_user(
permissions=permissions,
)

### Secrets routes
secrets_service = api_server_sql.SecretsApiService()

router.get("/api/secrets/", tags=["secrets"], **default_config)(
inject_session_dependency(
inject_user_name(secrets_service.list_secrets, parameter_name="user_id")
)
)
router.post("/api/secrets/", tags=["secrets"], **default_config)(
add_parameter_annotation_metadata(
inject_session_dependency(
inject_user_name(
secrets_service.create_secret, parameter_name="user_id"
)
),
parameter_name="secret_value",
annotation_metadata=fastapi.Body(embed=True),
)
)
router.put("/api/secrets/{secret_name}", tags=["secrets"], **default_config)(
add_parameter_annotation_metadata(
inject_session_dependency(
inject_user_name(
secrets_service.update_secret, parameter_name="user_id"
)
),
parameter_name="secret_value",
annotation_metadata=fastapi.Body(embed=True),
)
)
router.delete("/api/secrets/{secret_name}", tags=["secrets"], **default_config)(
inject_session_dependency(
inject_user_name(secrets_service.delete_secret, parameter_name="user_id")
)
)

### Component library routes

component_service = components_api.ComponentService()
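As a usage note (illustrative only, not part of this diff): the new routes can be exercised with any HTTP client. A minimal sketch, assuming a local server and that authentication is handled by the user-details dependency; secret_value travels in the JSON body because of Body(embed=True), while the remaining scalar parameters follow FastAPI's default of becoming query parameters:

import httpx  # any HTTP client works; httpx is just an illustrative choice

BASE_URL = "http://localhost:8000"  # assumed server address

with httpx.Client(base_url=BASE_URL) as client:
    # Create a secret. Creating it again should return 409 via the
    # ItemAlreadyExistsError handler registered above.
    response = client.post(
        "/api/secrets/",
        params={"secret_name": "MY_API_KEY"},
        json={"secret_value": "s3cr3t"},
    )
    response.raise_for_status()

    # Rotate the value, list the stored metadata, then delete.
    client.put("/api/secrets/MY_API_KEY", json={"secret_value": "new-value"})
    print(client.get("/api/secrets/").json())
    client.delete("/api/secrets/MY_API_KEY")
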
194 changes: 184 additions & 10 deletions cloud_pipelines_backend/api_server_sql.py
@@ -994,6 +994,147 @@ def get_signed_artifact_url(
return GetArtifactSignedUrlResponse(signed_url=signed_url)


# === Secrets Service
@dataclasses.dataclass(kw_only=True)
class SecretInfoResponse:
secret_name: str
created_at: datetime.datetime
updated_at: datetime.datetime
expires_at: datetime.datetime | None = None
description: str | None = None

@classmethod
def from_db(cls, secret_row: bts.Secret) -> "SecretInfoResponse":
return SecretInfoResponse(
secret_name=secret_row.secret_name,
created_at=secret_row.created_at,
updated_at=secret_row.updated_at,
expires_at=secret_row.expires_at,
description=secret_row.description,
)


@dataclasses.dataclass(kw_only=True)
class ListSecretsResponse:
secrets: list[SecretInfoResponse]


class SecretsApiService:

def create_secret(
self,
*,
session: orm.Session,
user_id: str,
secret_name: str,
secret_value: str,
description: str | None = None,
expires_at: datetime.datetime | None = None,
) -> SecretInfoResponse:
secret_name = secret_name.strip()
if not secret_name:
raise ApiServiceError(f"Secret name must not be empty.")
return self._create_or_update_secret(
session=session,
user_id=user_id,
secret_name=secret_name,
secret_value=secret_value,
description=description,
expires_at=expires_at,
raise_if_exists=True,
)

def update_secret(
self,
*,
session: orm.Session,
user_id: str,
secret_name: str,
secret_value: str,
description: str | None = None,
expires_at: datetime.datetime | None = None,
) -> SecretInfoResponse:
return self._create_or_update_secret(
session=session,
user_id=user_id,
secret_name=secret_name,
secret_value=secret_value,
description=description,
expires_at=expires_at,
raise_if_not_exists=True,
)

def _create_or_update_secret(
self,
*,
session: orm.Session,
user_id: str,
secret_name: str,
secret_value: str,
description: str | None = None,
expires_at: datetime.datetime | None = None,
raise_if_not_exists: bool = False,
raise_if_exists: bool = False,
) -> SecretInfoResponse:
current_time = _get_current_time()
secret = session.get(bts.Secret, (user_id, secret_name))
if secret:
if raise_if_exists:
raise errors.ItemAlreadyExistsError(
f"Secret with name '{secret_name}' already exists."
)
secret.secret_value = secret_value
secret.updated_at = current_time
else:
if raise_if_not_exists:
raise errors.ItemNotFoundError(
f"Secret with name '{secret_name}' does not exist."
)
secret = bts.Secret(
user_id=user_id,
secret_name=secret_name,
secret_value=secret_value,
created_at=current_time,
updated_at=current_time,
)
session.add(secret)
if description:
secret.description = description
if expires_at:
secret.expires_at = expires_at
response = SecretInfoResponse.from_db(secret)
session.commit()
return response

def delete_secret(
self,
*,
session: orm.Session,
user_id: str,
secret_name: str,
) -> None:
secret = session.get(bts.Secret, (user_id, secret_name))
if not secret:
raise errors.ItemNotFoundError(
f"Secret with name '{secret_name}' does not exist."
)
session.delete(secret)
session.commit()

def list_secrets(
self,
*,
session: orm.Session,
user_id: str,
) -> ListSecretsResponse:
secrets = session.scalars(
sql.select(bts.Secret).where(bts.Secret.user_id == user_id)
).all()
return ListSecretsResponse(
secrets=[SecretInfoResponse.from_db(secret) for secret in secrets]
)


# ============

# Idea for how to add deep nested graph:
@@ -1005,11 +1146,16 @@ def get_signed_artifact_url(
# No. Decided to first do topological sort and then 1-stage generation.


_ArtifactNodeOrDynamicDataType = typing.Union[
bts.ArtifactNode, structures.DynamicDataArgument
]


def _recursively_create_all_executions_and_artifacts_root(
session: orm.Session,
root_task_spec: structures.TaskSpec,
) -> bts.ExecutionNode:
input_artifact_nodes: dict[str, bts.ArtifactNode] = {}
input_artifact_nodes: dict[str, _ArtifactNodeOrDynamicDataType] = {}

root_component_spec = root_task_spec.component_ref.spec
if not root_component_spec:
@@ -1035,12 +1181,8 @@ def _recursively_create_all_executions_and_artifacts_root(
raise ApiServiceError(
f"root task arguments can only be constants, but got {input_name}={input_argument}. {root_task_spec=}"
)
elif not isinstance(input_argument, str):
raise ApiServiceError(
f"root task constant argument must be a string, but got {input_name}={input_argument}. {root_task_spec=}"
)
# TODO: Support constant input artifacts (artifact IDs)
if input_argument is not None:
elif isinstance(input_argument, str):
input_artifact_nodes[input_name] = (
# _construct_constant_artifact_node_and_add_to_session(
# session=session, value=input_argument, artifact_type=input_spec.type
@@ -1052,6 +1194,12 @@ def _recursively_create_all_executions_and_artifacts_root(
# This constant artifact won't be added to the DB
# TODO: Actually, they will be added...
# We don't need to link this input artifact here. It will be handled downstream.
elif isinstance(input_argument, structures.DynamicDataArgument):
input_artifact_nodes[input_name] = input_argument
else:
raise ApiServiceError(
f"root task constant argument must be a string, but got {input_name}={input_argument}. {root_task_spec=}"
)

root_execution_node = _recursively_create_all_executions_and_artifacts(
session=session,
Expand All @@ -1065,7 +1213,7 @@ def _recursively_create_all_executions_and_artifacts_root(
def _recursively_create_all_executions_and_artifacts(
session: orm.Session,
root_task_spec: structures.TaskSpec,
input_artifact_nodes: dict[str, bts.ArtifactNode],
input_artifact_nodes: dict[str, _ArtifactNodeOrDynamicDataType],
ancestors: list[bts.ExecutionNode],
) -> bts.ExecutionNode:
root_component_spec = root_task_spec.component_ref.spec
@@ -1098,6 +1246,26 @@ def _recursively_create_all_executions_and_artifacts(
input_artifact_nodes = dict(input_artifact_nodes)
for input_spec in root_component_spec.inputs or []:
input_artifact_node = input_artifact_nodes.get(input_spec.name)
if isinstance(input_artifact_node, structures.DynamicDataArgument):
if not (
isinstance(input_artifact_node.dynamic_data, str)
or (
isinstance(input_artifact_node.dynamic_data, dict)
and len(input_artifact_node.dynamic_data) == 1
)
):
raise ApiServiceError(
f"Dynamic data argument must be a string or a dict with a single key set, but got {input_artifact_node.dynamic_data}"
)
# Storing the dynamic data arguments for later use by the orchestrator.
extra_data = root_execution_node.extra_data or {}
extra_data.setdefault(
bts.EXECUTION_NODE_EXTRA_DATA_DYNAMIC_DATA_ARGUMENTS_KEY, {}
)[input_spec.name] = input_artifact_node.dynamic_data

root_execution_node.extra_data = extra_data
# Not adding any artifact link for secret inputs
continue
if input_artifact_node is None and not input_spec.optional:
if input_spec.default:
input_artifact_node = (
@@ -1163,7 +1331,8 @@ def _recursively_create_all_executions_and_artifacts(
root_execution_node.container_execution_status = (
bts.ContainerExecutionStatus.QUEUED
if all(
artifact_node.artifact_data
not isinstance(artifact_node, bts.ArtifactNode)
or artifact_node.artifact_data
for artifact_node in input_artifact_nodes.values()
)
else bts.ContainerExecutionStatus.WAITING_FOR_UPSTREAM
Expand All @@ -1190,10 +1359,12 @@ def _recursively_create_all_executions_and_artifacts(
raise ApiServiceError(
f"child_task_spec.component_ref.spec is empty. {child_task_spec=}"
)
child_task_input_artifact_nodes: dict[str, bts.ArtifactNode] = {}
child_task_input_artifact_nodes: dict[
str, _ArtifactNodeOrDynamicDataType
] = {}
for input_spec in child_component_spec.inputs or []:
input_argument = (child_task_spec.arguments or {}).get(input_spec.name)
input_artifact_node: bts.ArtifactNode | None = None
input_artifact_node: _ArtifactNodeOrDynamicDataType | None = None
if input_argument is None and not input_spec.optional:
# Not failing on unconnected required input if there is a default value
if input_spec.default is None:
@@ -1233,6 +1404,9 @@ def _recursively_create_all_executions_and_artifacts(
# artifact_type=input_spec.type,
# )
# )
elif isinstance(input_argument, structures.DynamicDataArgument):
# We'll deal with dynamic data (e.g. secrets) when launching the container.
input_artifact_node = input_argument
else:
raise ApiServiceError(
f"Unexpected task argument: {input_spec.name}={input_argument}. {child_task_spec=}"
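For context, a minimal sketch of driving SecretsApiService directly against a SQLAlchemy session (illustrative only; the engine setup, import paths, and metadata access are assumptions, and the service commits the session itself as shown above):

import sqlalchemy as sa
from sqlalchemy import orm

from cloud_pipelines_backend import api_server_sql
from cloud_pipelines_backend import backend_types_sql as bts

# Assumption: an in-memory SQLite database is enough for a local experiment.
engine = sa.create_engine("sqlite://")
bts._TableBase.metadata.create_all(engine)  # assumes _TableBase is the declarative base

service = api_server_sql.SecretsApiService()
with orm.Session(engine) as session:
    service.create_secret(
        session=session,
        user_id="user-1",
        secret_name="MY_API_KEY",
        secret_value="s3cr3t",
        description="Key for an external API",
    )
    listing = service.list_secrets(session=session, user_id="user-1")
    print([s.secret_name for s in listing.secrets])  # ['MY_API_KEY']
    service.delete_secret(session=session, user_id="user-1", secret_name="MY_API_KEY")
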
13 changes: 13 additions & 0 deletions cloud_pipelines_backend/backend_types_sql.py
@@ -406,6 +406,7 @@ class ExecutionNode(_TableBase):
EXECUTION_NODE_EXTRA_DATA_ORCHESTRATION_ERROR_MESSAGE_KEY = (
"orchestration_error_message"
)
EXECUTION_NODE_EXTRA_DATA_DYNAMIC_DATA_ARGUMENTS_KEY = "dynamic_data_arguments"
CONTAINER_EXECUTION_EXTRA_DATA_ORCHESTRATION_ERROR_MESSAGE_KEY = (
"orchestration_error_message"
)
@@ -476,3 +477,15 @@ class PipelineRunAnnotation(_TableBase):
pipeline_run: orm.Mapped[PipelineRun] = orm.relationship(repr=False, init=False)
key: orm.Mapped[str] = orm.mapped_column(default=None, primary_key=True)
value: orm.Mapped[str | None] = orm.mapped_column(default=None)


class Secret(_TableBase):
__tablename__ = "secret"
user_id: orm.Mapped[str] = orm.mapped_column(primary_key=True, index=True)
secret_name: orm.Mapped[str] = orm.mapped_column(primary_key=True)
secret_value: orm.Mapped[str]
created_at: orm.Mapped[datetime.datetime]
updated_at: orm.Mapped[datetime.datetime]
expires_at: orm.Mapped[datetime.datetime | None] = orm.mapped_column(default=None)
description: orm.Mapped[str | None] = orm.mapped_column(default=None)
extra_data: orm.Mapped[dict[str, Any] | None] = orm.mapped_column(default=None)
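Because (user_id, secret_name) is the composite primary key, secret names are unique per user and a row can be fetched with a key tuple via Session.get, as the service code above does. Nothing in this diff enforces expires_at yet; a consumer-side check might look like this sketch (the helper name and the naive-UTC timestamp assumption are illustrative, not part of the change):

import datetime

from sqlalchemy import orm

from cloud_pipelines_backend import backend_types_sql as bts


def get_unexpired_secret_value(
    session: orm.Session, user_id: str, secret_name: str
) -> str | None:
    # Composite primary-key lookup: (user_id, secret_name).
    secret = session.get(bts.Secret, (user_id, secret_name))
    if secret is None:
        return None
    # Assumption: timestamps are stored as naive UTC (matching the other columns).
    if secret.expires_at is not None and secret.expires_at <= datetime.datetime.utcnow():
        return None
    return secret.secret_value
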
21 changes: 20 additions & 1 deletion cloud_pipelines_backend/component_structures.py
@@ -317,7 +317,26 @@ class TaskOutputArgument(_BaseModel): # Has additional constructor for convenie
task_output: TaskOutputReference


ArgumentType = Union[PrimitiveTypes, GraphInputArgument, TaskOutputArgument]
DynamicDataReference = str | dict[str, Any]


@dataclasses.dataclass
class DynamicDataArgument(_BaseModel):
"""Argument that references data that's dynamically produced by the execution system at runtime.

Examples of dynamic data:
* Secret value
* Container execution ID
* Pipeline run ID
* Loop index/item
"""

dynamic_data: DynamicDataReference


ArgumentType = Union[
PrimitiveTypes, GraphInputArgument, TaskOutputArgument, DynamicDataArgument
]


@dataclasses.dataclass
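A sketch of how a task argument referencing dynamic data might be constructed. The only shape this PR validates (in api_server_sql.py) is "a string, or a dict with exactly one key"; the concrete key names the orchestrator will understand are not part of this diff, so the ones below are assumptions:

from cloud_pipelines_backend import component_structures as structures

# Dict-shaped reference with a single key (the key name "secret_name" is assumed).
secret_argument = structures.DynamicDataArgument(
    dynamic_data={"secret_name": "MY_API_KEY"}
)

# String-shaped reference; its meaning is defined by the orchestrator.
run_id_argument = structures.DynamicDataArgument(dynamic_data="pipeline_run_id")

task_arguments: dict[str, structures.ArgumentType] = {
    "api_key": secret_argument,
    "run_id": run_id_argument,
}
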
1 change: 1 addition & 0 deletions cloud_pipelines_backend/launchers/interfaces.py
@@ -36,6 +36,7 @@ class InputArgument:
value: str | None = None
uri: str | None = None
staging_uri: str
is_secret: bool = False


class ContainerTaskLauncher(typing.Generic[_TLaunchedContainer], abc.ABC):