Skip to content

Commit 13fc32b

Browse files
committed
Wrap consumer.poll() for KafkaConsumer iteration
1 parent 98ebff8 commit 13fc32b

File tree

2 files changed

+60
-7
lines changed

2 files changed

+60
-7
lines changed

kafka/consumer/fetcher.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ def _retrieve_offsets(self, timestamps, timeout_ms=float("inf")):
292292
raise Errors.KafkaTimeoutError(
293293
"Failed to get offsets by timestamps in %s ms" % (timeout_ms,))
294294

295-
def fetched_records(self, max_records=None):
295+
def fetched_records(self, max_records=None, update_offsets=True):
296296
"""Returns previously fetched records and updates consumed offsets.
297297
298298
Arguments:
@@ -330,10 +330,11 @@ def fetched_records(self, max_records=None):
330330
else:
331331
records_remaining -= self._append(drained,
332332
self._next_partition_records,
333-
records_remaining)
333+
records_remaining,
334+
update_offsets)
334335
return dict(drained), bool(self._completed_fetches)
335336

336-
def _append(self, drained, part, max_records):
337+
def _append(self, drained, part, max_records, update_offsets):
337338
if not part:
338339
return 0
339340

@@ -366,7 +367,8 @@ def _append(self, drained, part, max_records):
366367
for record in part_records:
367368
drained[tp].append(record)
368369

369-
self._subscriptions.assignment[tp].position = next_offset
370+
if update_offsets:
371+
self._subscriptions.assignment[tp].position = next_offset
370372
return len(part_records)
371373

372374
else:

kafka/consumer/group.py

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,8 @@ class KafkaConsumer(six.Iterator):
302302
'sasl_plain_password': None,
303303
'sasl_kerberos_service_name': 'kafka',
304304
'sasl_kerberos_domain_name': None,
305-
'sasl_oauth_token_provider': None
305+
'sasl_oauth_token_provider': None,
306+
'legacy_iterator': True, # experimental feature
306307
}
307308
DEFAULT_SESSION_TIMEOUT_MS_0_9 = 30000
308309

@@ -660,7 +661,7 @@ def _poll_once(self, timeout_ms, max_records):
660661

661662
# If data is available already, e.g. from a previous network client
662663
# poll() call to commit, then just return it immediately
663-
records, partial = self._fetcher.fetched_records(max_records)
664+
records, partial = self._fetcher.fetched_records(max_records, update_offsets=bool(self._iterator))
664665
if records:
665666
# Before returning the fetched records, we can send off the
666667
# next round of fetches and avoid block waiting for their
@@ -680,7 +681,7 @@ def _poll_once(self, timeout_ms, max_records):
680681
if self._coordinator.need_rejoin():
681682
return {}
682683

683-
records, _ = self._fetcher.fetched_records(max_records)
684+
records, _ = self._fetcher.fetched_records(max_records, update_offsets=bool(self._iterator))
684685
return records
685686

686687
def position(self, partition):
@@ -743,6 +744,9 @@ def pause(self, *partitions):
743744
for partition in partitions:
744745
log.debug("Pausing partition %s", partition)
745746
self._subscription.pause(partition)
747+
# Because the iterator checks is_fetchable() on each iteration
748+
# we expect pauses to get handled automatically and therefore
749+
# we do not need to reset the full iterator (forcing a full refetch)
746750

747751
def paused(self):
748752
"""Get the partitions that were previously paused using
@@ -790,6 +794,8 @@ def seek(self, partition, offset):
790794
assert partition in self._subscription.assigned_partitions(), 'Unassigned partition'
791795
log.debug("Seeking to offset %s for partition %s", offset, partition)
792796
self._subscription.assignment[partition].seek(offset)
797+
if not self.config['legacy_iterator']:
798+
self._iterator = None
793799

794800
def seek_to_beginning(self, *partitions):
795801
"""Seek to the oldest available offset for partitions.
@@ -814,6 +820,8 @@ def seek_to_beginning(self, *partitions):
814820
for tp in partitions:
815821
log.debug("Seeking to beginning of partition %s", tp)
816822
self._subscription.need_offset_reset(tp, OffsetResetStrategy.EARLIEST)
823+
if not self.config['legacy_iterator']:
824+
self._iterator = None
817825

818826
def seek_to_end(self, *partitions):
819827
"""Seek to the most recent available offset for partitions.
@@ -838,6 +846,8 @@ def seek_to_end(self, *partitions):
838846
for tp in partitions:
839847
log.debug("Seeking to end of partition %s", tp)
840848
self._subscription.need_offset_reset(tp, OffsetResetStrategy.LATEST)
849+
if not self.config['legacy_iterator']:
850+
self._iterator = None
841851

842852
def subscribe(self, topics=(), pattern=None, listener=None):
843853
"""Subscribe to a list of topics, or a topic regex pattern.
@@ -913,6 +923,8 @@ def unsubscribe(self):
913923
self._client.cluster.need_all_topic_metadata = False
914924
self._client.set_topics([])
915925
log.debug("Unsubscribed all topics or patterns and assigned partitions")
926+
if not self.config['legacy_iterator']:
927+
self._iterator = None
916928

917929
def metrics(self, raw=False):
918930
"""Get metrics on consumer performance.
@@ -1075,6 +1087,25 @@ def _update_fetch_positions(self, partitions):
10751087
# Then, do any offset lookups in case some positions are not known
10761088
self._fetcher.update_fetch_positions(partitions)
10771089

1090+
def _message_generator_v2(self):
1091+
timeout_ms = 1000 * (self._consumer_timeout - time.time())
1092+
record_map = self.poll(timeout_ms=timeout_ms)
1093+
for tp, records in six.iteritems(record_map):
1094+
# Generators are stateful, and it is possible that the tp / records
1095+
# here may become stale during iteration -- i.e., we seek to a
1096+
# different offset, pause consumption, or lose assignment.
1097+
for record in records:
1098+
# is_fetchable(tp) should handle assignment changes and offset
1099+
# resets; for all other changes (e.g., seeks) we'll rely on the
1100+
# outer function destroying the existing iterator/generator
1101+
# via self._iterator = None
1102+
if not self._subscription.is_fetchable(tp):
1103+
log.debug("Not returning fetched records for partition %s"
1104+
" since it is no longer fetchable", tp)
1105+
break
1106+
self._subscription.assignment[tp].position = record.offset + 1
1107+
yield record
1108+
10781109
def _message_generator(self):
10791110
assert self.assignment() or self.subscription() is not None, 'No topic subscription or manual partition assignment'
10801111
while time.time() < self._consumer_timeout:
@@ -1127,6 +1158,26 @@ def __iter__(self): # pylint: disable=non-iterator-returned
11271158
return self
11281159

11291160
def __next__(self):
1161+
# Now that the heartbeat thread runs in the background
1162+
# there should be no reason to maintain a separate iterator
1163+
# but we'll keep it available for a few releases just in case
1164+
if self.config['legacy_iterator']:
1165+
return self.next_v1()
1166+
else:
1167+
return self.next_v2()
1168+
1169+
def next_v2(self):
1170+
self._set_consumer_timeout()
1171+
while time.time() < self._consumer_timeout:
1172+
if not self._iterator:
1173+
self._iterator = self._message_generator_v2()
1174+
try:
1175+
return next(self._iterator)
1176+
except StopIteration:
1177+
self._iterator = None
1178+
raise StopIteration()
1179+
1180+
def next_v1(self):
11301181
if not self._iterator:
11311182
self._iterator = self._message_generator()
11321183

0 commit comments

Comments (0)