Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 86 additions & 22 deletions requests_toolbelt/multipart/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,12 +121,36 @@ def __init__(self, fields, boundary=None, encoding='utf-8'):
# Our buffer
self._buffer = CustomBytesIO(encoding=encoding)

# Number of bytes read from the encoder
self._bytes_read = 0

# Pre-compute each part's headers
self._prepare_parts()

# Load boundary into buffer
self._write_boundary()

def __iter__(self):
# Need to implement iterator protocol otherwise requests won't set
# `is_stream` to `True` and won't rewind the body on redirects.
return self

def __next__(self):
if self.finished:
raise StopIteration()
return self.read(8192)

def _reset(self):
"""Reset the encoder to the beginning."""
self.finished = False
for part in self.parts:
part.reset()
self._iter_parts = iter(self.parts)
self._current_part = None
self._buffer = CustomBytesIO(encoding=self.encoding)
self._bytes_read = 0
self._write_boundary()

@property
def len(self):
"""Length of the multipart/form-data body.
Expand Down Expand Up @@ -186,7 +210,7 @@ def _calculate_load_amount(self, read_size):

def _load(self, amount):
"""Load ``amount`` number of bytes into the buffer."""
self._buffer.smart_truncate()
smart_truncate(self._buffer)
part = self._current_part or self._next_part()
while amount == -1 or amount > 0:
written = 0
Expand Down Expand Up @@ -304,15 +328,30 @@ def read(self, size=-1):
remaining bytes.
:returns: bytes
"""
if self.finished:
return self._buffer.read(size)
if not self.finished:
bytes_to_load = size
if bytes_to_load != -1 and bytes_to_load is not None:
bytes_to_load = self._calculate_load_amount(int(size))

self._load(bytes_to_load)
string = self._buffer.read(size)
self._bytes_read += len(string)
return string

bytes_to_load = size
if bytes_to_load != -1 and bytes_to_load is not None:
bytes_to_load = self._calculate_load_amount(int(size))
def tell(self):
# type: () -> int
return self._bytes_read

self._load(bytes_to_load)
return self._buffer.read(size)
def seek(self, offset, whence=0):
# type: (int, int) -> int
if (offset, whence) == (0, 0):
self._reset()
elif (offset, whence) == (0, self._bytes_read) or (offset, whence) == (0, 1):
pass
else:
raise io.UnsupportedOperation(
"MultipartEncoder only supports seeking to the beginning")
return self.tell()


def IDENTITY(monitor):
Expand Down Expand Up @@ -377,10 +416,6 @@ def __init__(self, encoder, callback=None):
#: Optionally function to call after a read
self.callback = callback or IDENTITY

#: Number of bytes already read from the :class:`MultipartEncoder`
#: instance
self.bytes_read = 0

#: Avoid the same problem in bug #80
self.len = self.encoder.len

Expand All @@ -394,12 +429,16 @@ def from_fields(cls, fields, boundary=None, encoding='utf-8',
def content_type(self):
return self.encoder.content_type

@property
def bytes_read(self):
"""Number of bytes already read from the :class:`MultipartEncoder` instance."""
return self.encoder._bytes_read

def to_string(self):
return self.read()

def read(self, size=-1):
string = self.encoder.read(size)
self.bytes_read += len(string)
self.callback(self)
return string

Expand Down Expand Up @@ -486,6 +525,10 @@ def __init__(self, headers, body):
self.body = body
self.headers_unread = True
self.len = len(self.headers) + total_len(self.body)
try:
self.initial_pos = body.tell()
except (AttributeError, OSError, NotImplementedError):
self.initial_pos = None

@classmethod
def from_field(cls, field, encoding):
Expand Down Expand Up @@ -529,6 +572,20 @@ def write_to(self, buffer, size):

return written

def reset(self):
"""Reset the part to the beginning."""
if self.headers_unread:
return
if self.initial_pos is None:
raise io.UnsupportedOperation(
"Underlying body object does not support tell(). Cannot reset.")
try:
self.body.seek(self.initial_pos)
except AttributeError as e:
raise io.UnsupportedOperation(
"Underlying body object does not support seek(). Cannot reset.") from e
self.headers_unread = True


class CustomBytesIO(io.BytesIO):
def __init__(self, buffer=None, encoding='utf-8'):
Expand All @@ -552,16 +609,17 @@ def append(self, bytes):
written = self.write(bytes)
return written

def smart_truncate(self):
to_be_read = total_len(self)
already_read = self._get_end() - to_be_read

if already_read >= to_be_read:
old_bytes = self.read()
self.seek(0, 0)
self.truncate()
self.write(old_bytes)
self.seek(0, 0) # We want to be at the beginning
def smart_truncate(buf):
to_be_read = buf.len
already_read = buf.tell()

if already_read >= to_be_read:
old_bytes = buf.read()
buf.seek(0, 0)
buf.truncate()
buf.write(old_bytes)
buf.seek(0, 0) # We want to be at the beginning


class FileWrapper(object):
Expand All @@ -575,6 +633,12 @@ def len(self):
def read(self, length=-1):
return self.fd.read(length)

def tell(self):
return self.fd.tell()

def seek(self, offset, whence=0):
return self.fd.seek(offset, whence)


class FileFromURLWrapper(object):
"""File from URL wrapper.
Expand Down
78 changes: 67 additions & 11 deletions tests/test_multipart_encoder.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
# -*- coding: utf-8 -*-
import unittest
import io
import os
import tempfile

import requests

import pytest
from requests_toolbelt.multipart.encoder import (
CustomBytesIO, MultipartEncoder, FileFromURLWrapper, FileNotSupportedError)
CustomBytesIO, MultipartEncoder, FileFromURLWrapper, FileNotSupportedError,
smart_truncate)
from requests_toolbelt._compat import filepost
from . import get_betamax

Expand Down Expand Up @@ -78,7 +81,7 @@ def test_truncates_intelligently(self):
self.instance.write(b'abcdefghijklmnopqrstuvwxyzabcd') # 30 bytes
assert self.instance.tell() == 30
self.instance.seek(-10, 2)
self.instance.smart_truncate()
smart_truncate(self.instance)
assert self.instance.len == 10
assert self.instance.read() == b'uvwxyzabcd'
assert self.instance.tell() == 10
Expand Down Expand Up @@ -128,22 +131,25 @@ def test_no_content_length_header(self):
)


EXPECTED = (
b'--this-is-a-boundary\r\n'
b'Content-Disposition: form-data; name="field"\r\n\r\n'
b'value\r\n'
b'--this-is-a-boundary\r\n'
b'Content-Disposition: form-data; name="other_field"\r\n\r\n'
b'other_value\r\n'
b'--this-is-a-boundary--\r\n'
)


class TestMultipartEncoder(unittest.TestCase):
def setUp(self):
self.parts = [('field', 'value'), ('other_field', 'other_value')]
self.boundary = 'this-is-a-boundary'
self.instance = MultipartEncoder(self.parts, boundary=self.boundary)

def test_to_string(self):
assert self.instance.to_string() == (
'--this-is-a-boundary\r\n'
'Content-Disposition: form-data; name="field"\r\n\r\n'
'value\r\n'
'--this-is-a-boundary\r\n'
'Content-Disposition: form-data; name="other_field"\r\n\r\n'
'other_value\r\n'
'--this-is-a-boundary--\r\n'
).encode()
assert self.instance.to_string() == EXPECTED

def test_content_type(self):
expected = 'multipart/form-data; boundary=this-is-a-boundary'
Expand Down Expand Up @@ -202,6 +208,9 @@ def test_reads_file_from_url_wrapper(self):
[('field', 'foo'), ('file', FileFromURLWrapper(url, session=s))])
assert m.read() is not None

with pytest.raises(OSError):
m.seek(0, 0)

def test_reads_open_file_objects_with_a_specified_filename(self):
with open('setup.py', 'rb') as fd:
m = MultipartEncoder(
Expand Down Expand Up @@ -319,5 +328,52 @@ def test_no_parts(self):
output = m.read().decode('utf-8')
assert output == '----90967316f8404798963cce746a4f4ef9--\r\n'

def test_seeking(self):
field_data = self.parts[0][1].encode('utf-8')

tmpfile = tempfile.TemporaryFile()
tmpfile.write(field_data)
tmpfile.seek(0)
parts = self.parts.copy()
parts[0] = (self.parts[0][0], tmpfile)
m = MultipartEncoder(parts, boundary=self.boundary)

tmpfile = tempfile.TemporaryFile()
gunk = b"Some gunk at the beginning"
tmpfile.write(gunk)
tmpfile.write(field_data)
tmpfile.seek(len(gunk))
parts = self.parts.copy()
parts[0] = (self.parts[0][0], tmpfile)
m2 = MultipartEncoder(parts, boundary=self.boundary)

for instance in (self.instance, m, m2):
assert instance.tell() == 0
assert instance.read() == EXPECTED
# Exhausted:
assert instance.read() == b''
assert instance.seek(0) == 0
assert instance.read() == EXPECTED

def test_redirect(self):
"""Verifies integration with requests."""
tmpfile = tempfile.TemporaryFile()
tmpfile.write(b'from-file')
tmpfile.seek(0)

m = MultipartEncoder([('field', 'foo'), ('myfile', tmpfile)])
# Can't use betamax here - it responds too quickly and requests doesn't
# have time to start reading from the MultipartEncoder before the
# redirect response is returned - so the seek never happens.
resp = requests.post(
'https://httpbin.org/redirect-to?status_code=307&url=/post',
data=m, headers={'Content-Type': m.content_type},
timeout=10)
resp.raise_for_status()
print(resp.json())
assert resp.json()['form']['myfile'] == 'from-file'
assert resp.json()['form']['field'] == 'foo'


if __name__ == '__main__':
unittest.main()