Skip to content
200 changes: 156 additions & 44 deletions src/agents/sandbox/sandboxes/docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,72 @@

logger = logging.getLogger(__name__)


# Non-seekable payloads are spooled to measure their length; keep small ones in
# RAM and spill larger ones to a temp file so a big upload can't OOM the process.
_STREAM_SPOOL_MAX_SIZE = 16 * 1024 * 1024


def _measure_stream(stream: io.IOBase) -> tuple[int, io.IOBase, io.IOBase | None]:
"""Return ``(length, readable_stream, spool_to_close)`` for a length-framed write.

Seekable streams are measured in place (and rewound); ``spool_to_close`` is
``None``. Non-seekable streams (e.g. an HTTP response body or pipe) are copied
into a ``SpooledTemporaryFile`` — kept in memory up to
``_STREAM_SPOOL_MAX_SIZE``, spilled to disk beyond it — so the byte length can
be determined without buffering the whole payload in RAM; the spool is returned
so the caller can close it.

Callers run this on the executor thread, never the event loop.
"""
try:
start = stream.tell()
stream.seek(0, io.SEEK_END)
end = stream.tell()
stream.seek(start)
# Clamp to 0: a stream positioned past its end has no readable bytes, and
# a negative count would become `head -c -N` ("all but the last N bytes"),
# which reads to EOF and re-hangs over a TLS stdin.
return max(0, end - start), stream, None
except (AttributeError, OSError, ValueError):
spool: Any = tempfile.SpooledTemporaryFile(max_size=_STREAM_SPOOL_MAX_SIZE)
try:
length = 0
while True:
chunk = stream.read(1024 * 1024)
if not chunk:
break
if isinstance(chunk, str):
chunk = chunk.encode("utf-8")
length += len(chunk)
spool.write(chunk)
spool.seek(0)
return length, spool, spool
except BaseException:
# The caller only closes the spool once it is returned; on any error
# here it never receives it, so close it now to avoid a leaked temp
# file / buffer.
spool.close()
raise


# POSIX sh that pipes exactly ``<n>`` bytes into the real command (``"$@"``).
# ``head -c`` bounds the read so completion never depends on a stdin half-close
# (unreliable over a TLS DOCKER_HOST). A bare ``head -c "$n" | "$@"`` pipeline
# reports only the *consumer's* status, so if ``head`` can't produce the bytes —
# missing entirely, or a POSIX-only ``head`` that rejects ``-c`` (POSIX specifies
# only ``-n``) — the consumer (``cat``/``tar``) would see an empty pipe, exit 0,
# and the write would "succeed" after creating/truncating an empty file. Preflight
# ``head -c`` on known input and bail out (exit 98) unless it yields the expected
# byte, so such writes surface as errors instead of silent data loss. The check
# needs no writable path (avoiding a predictable /tmp status file that untrusted
# container code could pre-seed as a symlink for the root exec to follow) and no
# ``pipefail`` (not POSIX; dash lacks it).
_LENGTH_FRAMED_STDIN_SCRIPT = (
'n=$1; shift; [ "$(printf ab | head -c 1 2>/dev/null)" = a ] || exit 98; head -c "$n" | "$@"'
)


_PREPARE_USER_PTY_PID_SCRIPT = (
'pid_path="$1"\n'
'pid_user="$2"\n'
Expand Down Expand Up @@ -562,57 +628,103 @@ async def _stream_into_exec(
error_path: Path,
user: str | User | None = None,
) -> None:
# Frame the payload by length so the in-container reader terminates on a
# byte count rather than a stdin half-close. Docker's exec-attach stream
# does not carry a reliable stdin EOF over a TLS DOCKER_HOST: the
# ``shutdown(SHUT_WR)`` below is silently swallowed, so ``tar -x`` / ``cat``
# would block forever waiting for input that never ends (observed against
# Docker-in-Docker sidecars and remote daemons reached over TLS). Piping
# the real command through ``head -c <n>`` makes it stop after exactly
# ``<n>`` bytes, independent of transport, and keeps the deliberate
# avoidance of ``put_archive()`` (see ``write``) intact.
def _write() -> int | None:
container_client = self._container.client
assert container_client is not None
api = container_client.api
resp = api.exec_create(
self._container.id,
cmd,
stdin=True,
stdout=True,
stderr=True,
workdir=None,
user=self._coerce_exec_user(user) or "",
)
exec_socket = self._start_exec_socket(api=api, exec_id=cast(str, resp["Id"]))
sock = exec_socket.sock
raw_sock = exec_socket.raw_sock
try:
while True:
chunk = stream.read(1024 * 1024)
if not chunk:
break
if isinstance(chunk, str):
chunk = chunk.encode("utf-8")
elif not isinstance(chunk, bytes):
chunk = bytes(chunk)
if hasattr(raw_sock, "sendall"):
raw_sock.sendall(chunk)
else:
cast(Any, sock).write(chunk)

try:
if hasattr(raw_sock, "shutdown"):
raw_sock.shutdown(socket.SHUT_WR)
else:
cast(Any, sock).flush()
except Exception:
pass

# Measure/spool on this executor thread (never the event loop). A
# non-seekable stream is spooled to a SpooledTemporaryFile (bounded
# memory, then disk) rather than read whole into RAM.
payload_length, read_stream, spool = _measure_stream(stream)
try:
framed_cmd = [
"sh",
"-c",
_LENGTH_FRAMED_STDIN_SCRIPT,
"sh",
str(payload_length),
*cmd,
]
resp = api.exec_create(
self._container.id,
framed_cmd,
stdin=True,
stdout=True,
stderr=True,
workdir=None,
user=self._coerce_exec_user(user) or "",
)
exec_socket = self._start_exec_socket(api=api, exec_id=cast(str, resp["Id"]))
sock = exec_socket.sock
raw_sock = exec_socket.raw_sock
try:
if hasattr(raw_sock, "recv"):
while raw_sock.recv(1024 * 1024):
pass
else:
while cast(Any, sock).read(1024 * 1024):
pass
except Exception:
pass
# Send exactly ``payload_length`` bytes — the count the exec
# was framed with (``head -c "$n"``). Reading to EOF instead
# would desync if the stream changed after _measure_stream:
# extra bytes would pile up behind a ``head`` that already
# stopped, and a short read would leave ``head`` blocked on a
# TLS stdin that never EOFs (the original hang). If the stream
# ends early we fail loudly rather than truncate silently.
remaining = payload_length
while remaining > 0:
chunk = read_stream.read(min(1024 * 1024, remaining))
if not chunk:
raise WorkspaceArchiveWriteError(
path=error_path,
context={
"reason": "stream_shorter_than_measured",
"expected": str(payload_length),
"sent": str(payload_length - remaining),
},
)
if isinstance(chunk, str):
chunk = chunk.encode("utf-8")
elif not isinstance(chunk, bytes):
chunk = bytes(chunk)
if len(chunk) > remaining:
# Only reachable for multibyte text streams (never the
# byte streams these writes use); cap to the framed count.
chunk = chunk[:remaining]
if hasattr(raw_sock, "sendall"):
raw_sock.sendall(chunk)
else:
cast(Any, sock).write(chunk)
remaining -= len(chunk)

try:
if hasattr(raw_sock, "shutdown"):
raw_sock.shutdown(socket.SHUT_WR)
else:
cast(Any, sock).flush()
except Exception:
pass

try:
if hasattr(raw_sock, "recv"):
while raw_sock.recv(1024 * 1024):
pass
else:
while cast(Any, sock).read(1024 * 1024):
pass
except Exception:
pass
finally:
exec_socket.close()

return cast(int | None, api.exec_inspect(resp["Id"]).get("ExitCode"))
finally:
exec_socket.close()

return cast(int | None, api.exec_inspect(resp["Id"]).get("ExitCode"))
if spool is not None:
spool.close()

loop = asyncio.get_running_loop()
try:
Expand Down
Loading