## Why Should Java Programmers Learn About Generators?
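For processing large data sets, Java has the Stream API and Python has generators.
The core idea is the same: lazy evaluation, where each value is computed only when it is requested. Python's generator syntax, however, is more concise.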
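For readers coming from Java, here is a rough Python analogue of `Stream.iterate(1, n -> n + 1).filter(n -> n % 3 == 0).map(n -> n * n).limit(5)`. It chains generator expressions over the infinite `itertools.count()`. The one-to-one mapping in the comments is our own illustration, not an official correspondence between the two APIs. Nothing is computed until `islice` starts pulling values.

```python
from itertools import count, islice

naturals = count(1)                                    # 1, 2, 3, ... (infinite, like Stream.iterate)
multiples_of_3 = (n for n in naturals if n % 3 == 0)   # like .filter(...)
squares = (n * n for n in multiples_of_3)              # like .map(...)
first_five = list(islice(squares, 5))                  # like .limit(5) plus collecting to a list
print(first_five)  # [9, 36, 81, 144, 225]
```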
## 1. Iterator Basics
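```python
# A for loop is really driven by an iterator: iter() obtains one,
# and next() pulls values from it until StopIteration is raised
nums = [1, 2, 3, 4, 5]
it = iter(nums)
while True:
    try:
        num = next(it)
        print(num)
    except StopIteration:
        break
```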
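Java's `Iterator` exposes `hasNext()`, but Python signals the end of iteration with `StopIteration`. The built-in `next()` also accepts a default value, which lets you write the loop above without `try`/`except`. Here is a small sketch, assuming Python 3.8+ for the `:=` operator. The `sentinel` object is our own device, used so that no real element can be mistaken for "end of data".

```python
nums = [1, 2, 3, 4, 5]
it = iter(nums)
sentinel = object()  # unique marker that can never equal a real element
collected = []
# next(iterator, default) returns the default instead of raising StopIteration
while (num := next(it, sentinel)) is not sentinel:
    collected.append(num)
print(collected)  # [1, 2, 3, 4, 5]
```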
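### Custom Iterators
Any class that implements `__iter__` (returning the iterator itself) and `__next__` (returning the next value, or raising `StopIteration` when done) can be used in a `for` loop: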
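```python
class CountDown:
    """Counts down from start to 1 using the iterator protocol."""

    def __init__(self, start):
        self.current = start

    def __iter__(self):
        # An iterator returns itself from __iter__
        return self

    def __next__(self):
        # Signal exhaustion once the count reaches zero
        if self.current <= 0:
            raise StopIteration
        self.current -= 1
        return self.current + 1


for num in CountDown(5):
    print(num)  # 5, 4, 3, 2, 1
```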
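Because `CountDown.__iter__` returns `self`, each instance can be looped over only once. This is the same distinction Java draws between an `Iterator` and an `Iterable` whose `iterator()` hands out a fresh iterator on each call. The sketch below contrasts the two. `CountDownRange` is a hypothetical name, and it borrows the generator syntax from the next section to implement `__iter__`.

```python
class CountDown:
    """One-shot iterator (same as above): __iter__ returns self."""

    def __init__(self, start):
        self.current = start

    def __iter__(self):
        return self

    def __next__(self):
        if self.current <= 0:
            raise StopIteration
        self.current -= 1
        return self.current + 1


class CountDownRange:
    """Reusable iterable (hypothetical): each __iter__ call starts a fresh countdown."""

    def __init__(self, start):
        self.start = start

    def __iter__(self):
        # A generator method: every call returns a new, independent iterator
        n = self.start
        while n > 0:
            yield n
            n -= 1


once = CountDown(3)
print(list(once), list(once))  # [3, 2, 1] []
many = CountDownRange(3)
print(list(many), list(many))  # [3, 2, 1] [3, 2, 1]
```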
## 2. Generators: Syntactic Sugar for Iterators
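```python
def countdown(start):
    while start > 0:
        yield start  # yield instead of return: hand back a value and pause here
        start -= 1


for num in countdown(5):
    print(num)  # 5, 4, 3, 2, 1
```

Calling `countdown(5)` does not run the function body. It returns a generator object that implements `__iter__` and `__next__` for you.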
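To make the "pause and resume" behavior visible, here is a small sketch that records each step in a list. `countdown_traced` and its `log` parameter are hypothetical additions for illustration. The body only starts running on the first `next()`, and each later `next()` resumes right after the `yield` where it stopped.

```python
def countdown_traced(start, log):
    log.append("started")
    while start > 0:
        log.append(f"yielding {start}")
        yield start
        log.append(f"resumed after {start}")
        start -= 1
    log.append("finished")


log = []
gen = countdown_traced(2, log)
print(log)              # [] (calling the function only created the generator object)
print(next(gen))        # 2
print(log)              # ['started', 'yielding 2']
print(next(gen))        # 1
print(next(gen, None))  # None (the body ran to the end and StopIteration was raised)
print(log)
```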
### Generators vs. Lists
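```python
import sys

# List: every element is computed and stored up front
squares_list = [x**2 for x in range(10000)]
print(sys.getsizeof(squares_list))  # ~87,624 bytes

# Generator: values are produced on demand
squares_gen = (x**2 for x in range(10000))
print(sys.getsizeof(squares_gen))  # ~200 bytes
```

That is roughly a 400x difference! The exact figures vary by CPython version and platform. Also, `sys.getsizeof` counts only the list's own array of references, not the int objects it points to, so the real gap is even larger.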
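For a fuller picture than `sys.getsizeof`, here is a rough sketch, assuming CPython and the standard `tracemalloc` module. It compares peak traced memory while summing over a list comprehension versus a generator expression. `peak_bytes` is a hypothetical helper. Absolute numbers differ across versions and platforms, so no expected output is shown.

```python
import tracemalloc


def peak_bytes(fn):
    """Run fn() and return the peak memory traced by tracemalloc, in bytes."""
    tracemalloc.start()
    try:
        fn()
        _, peak = tracemalloc.get_traced_memory()  # returns (current, peak)
    finally:
        tracemalloc.stop()
    return peak


N = 100_000
list_peak = peak_bytes(lambda: sum([x**2 for x in range(N)]))  # whole list lives in memory
gen_peak = peak_bytes(lambda: sum(x**2 for x in range(N)))     # one value at a time
print(f"list: {list_peak:,} bytes, generator: {gen_peak:,} bytes")
```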
### Processing Large Files
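```python
def read_large_file(filepath):
    # The file object is itself a lazy iterator: lines are yielded one at a time
    # instead of loading the whole file into memory
    with open(filepath, 'r') as f:
        for line in f:
            yield line.strip()


# Processing a 10 GB log this way uses roughly constant memory
error_count = 0
for line in read_large_file("/var/log/app.log"):
    if "ERROR" in line:
        error_count += 1
```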
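To try the same idea without a real `/var/log/app.log`, here is a self-contained sketch. It writes a small hypothetical log to a temporary file and condenses the counting loop into a generator expression fed to `sum()`. The sample log contents are invented for illustration.

```python
import os
import tempfile


def read_large_file(filepath):
    # Same generator as above: yields one stripped line at a time
    with open(filepath, 'r') as f:
        for line in f:
            yield line.strip()


# Hypothetical sample log written to a temporary file
sample = "INFO start\nERROR disk full\nINFO retry\nERROR timeout\n"
with tempfile.NamedTemporaryFile("w", suffix=".log", delete=False) as tmp:
    tmp.write(sample)
    path = tmp.name

try:
    # Count matching lines without ever building a list of lines
    error_count = sum(1 for line in read_large_file(path) if "ERROR" in line)
finally:
    os.remove(path)

print(error_count)  # 2
```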
## 3. Generator Expressions
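```python
# List comprehension -> generator expression: just swap [] for ()
squares = (x**2 for x in range(1000000))

# Fed straight into sum(): no intermediate list is ever built
total = sum(x**2 for x in range(1000000))

# Find the first square greater than 1000
first_big = next(x**2 for x in range(100) if x**2 > 1000)
print(first_big)  # 1024
```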
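Two practical details are worth knowing. A generator expression needs its own parentheses whenever it is not the sole argument of a call, such as when you pass `next()` a default. Consumers like `any()` also stop pulling values as soon as they have an answer. The sketch below shows both. `is_big` and the `checked` list are hypothetical helpers that record how many values were actually evaluated.

```python
# As the sole argument of a call, a generator expression needs no extra parentheses
total = sum(x**2 for x in range(10))
print(total)  # 285

# With a second argument (next's default) it needs its own parentheses;
# the default is returned instead of raising StopIteration when nothing matches
first_huge = next((x**2 for x in range(100) if x**2 > 10**6), None)
print(first_huge)  # None

# any() short-circuits: it stops pulling values at the first True
checked = []


def is_big(x):
    checked.append(x)
    return x**2 > 1000


print(any(is_big(x) for x in range(100)))  # True
print(len(checked))  # 33 (only 0..32 were evaluated)
```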
## 4. yield from: Delegating to a Sub-generator
```python
def flatten(nested_list):
    for item in nested_list:
        if isinstance(item, list):
            # Delegate to a recursive sub-generator and re-yield everything it yields
            yield from flatten(item)
        else:
            yield item


nested = [1, [2, 3], [4, [5, 6]], 7]
print(list(flatten(nested)))  # [1, 2, 3, 4, 5, 6, 7]
```
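`yield from` does more than re-yield values. When the sub-generator finishes with `return value`, that value becomes the result of the `yield from` expression. Here is a minimal sketch, with hypothetical function names `read_block` and `read_all`, in which each sub-generator reports how many items it produced.

```python
def read_block(block):
    """Sub-generator: yields each item, then returns how many it produced."""
    count = 0
    for item in block:
        yield item
        count += 1
    return count  # becomes the value of the `yield from` expression


def read_all(blocks, counts):
    """Delegating generator: re-yields items and records each block's size."""
    for block in blocks:
        n = yield from read_block(block)
        counts.append(n)


counts = []
items = list(read_all([[1, 2], [3], []], counts))
print(items)   # [1, 2, 3]
print(counts)  # [2, 1, 0]
```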
## 5. The itertools Module
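```python
import itertools

# Combinations: every 2-element combination, in input order
items = ['A', 'B', 'C', 'D']
print(list(itertools.combinations(items, 2)))
# [('A', 'B'), ('A', 'C'), ('A', 'D'), ('B', 'C'), ('B', 'D'), ('C', 'D')]

# Grouping: groups runs of consecutive items that share a key
data = [('A', 1), ('A', 2), ('B', 3), ('B', 4)]
for key, group in itertools.groupby(data, key=lambda x: x[0]):
    print(f"{key}: {list(group)}")
# A: [('A', 1), ('A', 2)]
# B: [('B', 3), ('B', 4)]

# Chaining: iterate over several iterables as if they were one
combined = itertools.chain([1, 2], [3, 4], [5, 6])
print(list(combined))  # [1, 2, 3, 4, 5, 6]
```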
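A common surprise for Java developers: `Collectors.groupingBy` collects everything into a map, but `itertools.groupby` only merges *consecutive* items with equal keys. The example above works because its data is already ordered by key. Here is a sketch with hypothetical unsorted input, showing the pitfall and the usual fix of sorting by the same key first.

```python
from itertools import groupby
from operator import itemgetter

data = [('A', 1), ('B', 3), ('A', 2), ('B', 4)]  # hypothetical unsorted input

# Without sorting, each run of equal keys becomes its own group
run_keys = [key for key, _ in groupby(data, key=itemgetter(0))]
print(run_keys)  # ['A', 'B', 'A', 'B']

# Sort by the same key first to get exactly one group per key
ordered = sorted(data, key=itemgetter(0))
grouped = {key: [value for _, value in group]
           for key, group in groupby(ordered, key=itemgetter(0))}
print(grouped)  # {'A': [1, 2], 'B': [3, 4]}
```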
## 6. In Practice: A Data Pipeline
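```python
def read_data(filepath):
    # Stage 1: read lines lazily and split each into fields
    with open(filepath) as f:
        for line in f:
            yield line.strip().split(',')


def filter_by_age(records, min_age):
    # Stage 2: keep only records whose age (second field) is >= min_age
    for record in records:
        if int(record[1]) >= min_age:
            yield record


def format_output(records):
    # Stage 3: turn each record into a display string
    for record in records:
        yield f"Name: {record[0]}, Age: {record[1]}"


# Chain the stages into a pipeline (assumes users.csv holds one "name,age" record
# per line with no header row). Even with 1 million rows, only one record is in
# flight at a time, so memory use stays roughly constant regardless of file size.
pipeline = format_output(filter_by_age(read_data("users.csv"), 18))
for result in pipeline:
    print(result)
```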
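The pipeline is pull-based: each stage asks the previous one for a record only when its own consumer asks for one. Here is a minimal sketch to observe this. It assumes a small hypothetical list of CSV lines in place of `users.csv`, and a hypothetical `parse` stage that records every line it reads. Asking for just the first two results reads only three input lines.

```python
from itertools import islice


def parse(lines, consumed):
    # Hypothetical Stage 1 variant: parses an in-memory iterable and logs what it reads
    for line in lines:
        consumed.append(line)
        yield line.strip().split(',')


def filter_by_age(records, min_age):
    for record in records:
        if int(record[1]) >= min_age:
            yield record


def format_output(records):
    for record in records:
        yield f"Name: {record[0]}, Age: {record[1]}"


lines = ["Alice,30", "Bob,15", "Carol,22", "Dave,40", "Eve,17"]  # hypothetical data
consumed = []
pipeline = format_output(filter_by_age(parse(lines, consumed), 18))
first_two = list(islice(pipeline, 2))
print(first_two)      # ['Name: Alice, Age: 30', 'Name: Carol, Age: 22']
print(len(consumed))  # 3 (Dave and Eve were never read)
```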
## Summary
| Feature | Python generators | Java Stream |
|---|---|---|
| Nested processing | `yield from` | `flatMap` |
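To make the `yield from` / `flatMap` row concrete, here is a sketch of a generator that behaves like Java's `Stream.flatMap`. `flat_map` is a hypothetical helper name, not a standard-library function. The second form uses only the standard library's `itertools.chain.from_iterable`.

```python
from itertools import chain


def flat_map(func, iterable):
    """Generator analogue of Java's Stream.flatMap (hypothetical helper)."""
    for item in iterable:
        yield from func(item)


sentences = ["lazy evaluation", "on demand"]
words = list(flat_map(str.split, sentences))
print(words)  # ['lazy', 'evaluation', 'on', 'demand']

# The same result using only the standard library
words_std = list(chain.from_iterable(map(str.split, sentences)))
print(words_std)  # ['lazy', 'evaluation', 'on', 'demand']
```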
This series is ongoing; follow along so you don't miss the next installment.