元类、魔法方法、装饰器、混入类、协程
1. 元类(Metaclass)进阶
1.1 动态属性控制
class AutoPropertyMeta(type):
    """Metaclass that turns every UPPER_CASE class attribute into a property.

    For each upper-case name ``KEY`` in the class body, the declared value is
    moved to a private class attribute ``_key`` (the default every instance
    sees) and ``KEY`` is replaced by a property whose getter and setter read
    and write ``self._key``.
    """

    def __new__(mcs, name, bases, dct):
        for key in list(dct):  # iterate over a copy: dct is mutated below
            if not key.isupper():
                continue
            private_key = f"_{key.lower()}"
            # Keep the declared value as the default. It used to be replaced
            # with None, so Config().DEBUG returned None instead of True.
            dct[private_key] = dct[key]
            dct[key] = mcs._make_property(private_key)
        return super().__new__(mcs, name, bases, dct)

    @staticmethod
    def _make_property(private_key):
        """Return a read/write property proxying ``self.<private_key>``."""

        def getter(self):
            return getattr(self, private_key)

        def setter(self, value):
            setattr(self, private_key, value)

        return property(getter, setter)


class Config(metaclass=AutoPropertyMeta):
    DEBUG = True
    SECRET_KEY = "default"


config = Config()
print(config.DEBUG)  # True
config.SECRET_KEY = "new_secret"
print(config._secret_key)  # "new_secret"

# 1.2 Class registration system (类注册系统)
class PluginRegistryMeta(type):
    """Metaclass that records every concrete class it creates in ``registry``.

    A class opts out by setting ``__abstract__ = True`` in its *own* body.
    The flag is read from the class namespace rather than with ``getattr``,
    so subclasses do not inherit it. With ``getattr``, every subclass of an
    abstract base also looked abstract, and nothing was ever registered.
    Keys are bare class names, so a later class with the same name replaces
    an earlier entry.
    """

    registry = {}

    def __init__(cls, name, bases, attrs):
        super().__init__(name, bases, attrs)
        if not attrs.get("__abstract__", False):
            PluginRegistryMeta.registry[name] = cls


class PluginBase(metaclass=PluginRegistryMeta):
    __abstract__ = True

    def execute(self):
        raise NotImplementedError


class EmailPlugin(PluginBase):
    def execute(self):
        return "Sending email..."


class SMSPlugin(PluginBase):
    def execute(self):
        return "Sending SMS..."


# Plugins registered automatically by the metaclass
print(PluginRegistryMeta.registry)
# {'EmailPlugin': <class '__main__.EmailPlugin'>,
#  'SMSPlugin': <class '__main__.SMSPlugin'>}

# 1.3 Advanced ORM implementation (ORM 高级实现)
class Field:
    """Column declaration for a Model subclass."""

    def __init__(self, field_type, primary_key=False):
        self.field_type = field_type
        self.primary_key = primary_key

    def __repr__(self):
        # Readable output in Model.save(); the default repr is only an address.
        type_name = getattr(self.field_type, "__name__", repr(self.field_type))
        return f"Field({type_name}, primary_key={self.primary_key})"


class ModelMeta(type):
    """ORM metaclass: collects Field declarations into ``cls._fields``.

    Fields declared on base models are inherited; a subclass declaration
    overrides a same-named inherited one. ``cls._table_name`` is the
    lower-cased class name.
    """

    def __new__(mcs, name, bases, attrs):
        fields = {}
        # Apply bases in reverse so that earlier bases win, matching the MRO.
        for base in reversed(bases):
            fields.update(getattr(base, "_fields", {}))
        fields.update(
            (key, value) for key, value in attrs.items() if isinstance(value, Field)
        )
        cls = super().__new__(mcs, name, bases, attrs)
        cls._fields = fields
        cls._table_name = name.lower()
        return cls


class Model(metaclass=ModelMeta):
    """ORM base class; keyword arguments become instance attributes."""

    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)

    def save(self):
        # Simulated persistence: report what would be written.
        print(f"Saving {self.__class__.__name__} to database")
        print(f"Fields: {self._fields}")
        print(f"Values: {self.__dict__}")


class User(Model):
    id = Field(int, primary_key=True)
    name = Field(str)
    email = Field(str)


user = User(id=1, name="Alice", email="alice@example.com")
user.save()

# 2. Advanced magic methods (魔法方法进阶)
2.1 上下文管理器增强
class Transaction:
    """Transaction manager that can retry a failing unit of work.

    A ``with`` body runs exactly once, because ``__exit__`` cannot re-execute
    it. Suppressing the exception there, as this class used to do, only hid
    the failure while printing "Retry". Real retries are driven by iterating
    the transaction, with one attempt per iteration::

        for tx in Transaction(max_retries=2):
            with tx:
                do_work()

    A failing attempt is suppressed and retried while ``retries <
    max_retries``; after that, the exception propagates. When the class is
    used as a plain ``with`` statement (outside such a loop), a failure always
    propagates.
    """

    def __init__(self, max_retries=3):
        self.max_retries = max_retries
        self.retries = 0
        self._retrying = False  # True only while driven by __iter__
        self._committed = False
        self._exited = False

    def __iter__(self):
        self.retries = 0
        self._committed = False
        self._retrying = True
        try:
            while not self._committed:
                self._exited = False
                yield self
                # Guard against looping forever if the body never used `with tx:`.
                if not self._exited:
                    raise RuntimeError(
                        "each iteration must run its work inside 'with tx:'"
                    )
        finally:
            self._retrying = False

    def __enter__(self):
        print("Starting transaction")
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self._exited = True
        if exc_type is None:
            print("Committing transaction")
            self._committed = True
            return False
        if self._retrying and self.retries < self.max_retries:
            self.retries += 1
            print(f"Retry {self.retries}/{self.max_retries}")
            return True  # drop this attempt's error; __iter__ starts another
        print("Transaction failed after retries")
        return False  # propagate the exception


# Usage: the first attempt fails and the second one commits.
for tx in Transaction(max_retries=2):
    with tx:
        print("Performing operation")
        if tx.retries < 1:  # simulate a transient failure on the first try
            raise ValueError("Temporary failure")

# 2.2 Advanced attribute access (高级属性访问)
class LazyLoader:
    """Descriptor that computes an attribute on first access, per instance.

    The loaded value is cached in the owning instance's ``__dict__`` under
    the attribute's name. Caching it on the descriptor, as before, shared one
    value across *all* instances, so HeavyData(2).data returned "Data for 1".
    Instances must have a ``__dict__`` (``__slots__``-only classes are not
    supported).
    """

    def __init__(self, loader_func):
        self.loader_func = loader_func
        self.name = getattr(loader_func, "__name__", None)
        self.__doc__ = loader_func.__doc__

    def __set_name__(self, owner, name):
        # Use the name the descriptor is actually bound to in the class body.
        self.name = name

    def __get__(self, instance, owner):
        if instance is None:
            return self  # accessed on the class itself
        cache = instance.__dict__
        if self.name not in cache:
            cache[self.name] = self.loader_func(instance)
        return cache[self.name]

    def __set__(self, instance, value):
        instance.__dict__[self.name] = value


class HeavyData:
    def __init__(self, id):
        self.id = id

    @LazyLoader
    def data(self):
        print(f"Loading heavy data for {self.id}")
        # Stand-in for an expensive operation.
        return f"Data for {self.id}"


obj = HeavyData(1)
print(obj.data)  # loaded on first access
print(obj.data)  # served from the per-instance cache

# 2.3 Custom iteration protocol (自定义迭代协议)
class PaginatedAPI:
    """Iterator over a (simulated) paginated API, yielding one page per step.

    Pages are fetched exactly once, in order, starting from page 0. The total
    item count is learned from the first response. Iteration stops once
    ``current_page * page_size`` reaches that total, or when an empty page
    arrives. The object is its own iterator, so it can be consumed only once.
    """

    def __init__(self, base_url, page_size=10):
        self.base_url = base_url
        self.page_size = page_size
        self.current_page = 0
        self.total_items = None  # unknown until the first page is fetched

    def __iter__(self):
        return self

    def __next__(self):
        if (
            self.total_items is not None
            and self.current_page * self.page_size >= self.total_items
        ):
            raise StopIteration
        # Fetch the current page, then advance. The old code advanced first,
        # which skipped page 0's data and yielded a trailing empty page.
        page = self._fetch_page(self.current_page)
        self.current_page += 1
        if not page:
            raise StopIteration
        return page

    def _fetch_page(self, page):
        # Simulated API request.
        print(f"Fetching page {page}")
        # Canned data: 25 items in pages of 10; page 0 also reports the total.
        if page == 0:
            self.total_items = 25
            return list(range(1, 11))
        elif page == 1:
            return list(range(11, 21))
        elif page == 2:
            return list(range(21, 26))
        return []


# Usage
api = PaginatedAPI("https://api.example.com/data")
for page in api:
    print(f"Page data: {page}")

# 3. Advanced decorators (装饰器进阶)
3.1 带状态的类装饰器
import functools
import threading
import time


class RateLimiter:
    """Decorator limiting the wrapped function to ``calls_per_second`` calls.

    The limit applies across all threads. The lock is held while sleeping, so
    concurrent callers queue up and are released one interval apart. The
    wrapped function itself runs outside the lock.

    Raises:
        ValueError: ``calls_per_second`` is not positive.
    """

    def __init__(self, calls_per_second):
        if calls_per_second <= 0:
            raise ValueError("calls_per_second must be positive")
        self.calls_per_second = calls_per_second
        # -inf so the first call never waits; monotonic() has an arbitrary origin.
        self.last_called = float("-inf")
        self.lock = threading.Lock()

    def __call__(self, func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            interval = 1 / self.calls_per_second
            with self.lock:
                # monotonic() is immune to wall-clock jumps, unlike time.time().
                wait = interval - (time.monotonic() - self.last_called)
                if wait > 0:
                    time.sleep(wait)
                self.last_called = time.monotonic()
            return func(*args, **kwargs)

        return wrapper
# Imported here as well so this snippet runs on its own.
import threading
import time


# Usage
@RateLimiter(calls_per_second=2)
def api_call():
    print("API called at", time.time())


# Test: start five threads 0.1 s apart; the limiter still spaces the calls
# about 0.5 s apart.
threads = [threading.Thread(target=api_call) for _ in range(5)]
for thread in threads:
    thread.start()
    time.sleep(0.1)
# Join so this demo's output finishes before the next section runs.
for thread in threads:
    thread.join()

# 3.2 Decorator factory (装饰器工厂)
import functools
import inspect


def _check_annotation(param, value):
    """Raise TypeError if *value* does not match *param*'s class annotation.

    Unannotated parameters are skipped, as are annotations that ``isinstance``
    cannot use (string forward references, ``list[int]``, other typing
    constructs). Following PEP 484's numeric tower, ``int`` is accepted where
    ``float`` is annotated, and ``int``/``float`` where ``complex`` is.
    """
    annotation = param.annotation
    if annotation is param.empty:
        return
    if annotation is float:
        expected = (int, float)
    elif annotation is complex:
        expected = (int, float, complex)
    else:
        expected = annotation
    if param.kind is param.VAR_POSITIONAL:
        values = value  # tuple of the extra positional arguments
    elif param.kind is param.VAR_KEYWORD:
        values = value.values()  # dict of the extra keyword arguments
    else:
        values = (value,)
    for item in values:
        try:
            ok = isinstance(item, expected)
        except TypeError:
            return  # annotation is not a runtime-checkable class
        if not ok:
            raise TypeError(f"Argument '{param.name}' must be {annotation}")


def validate_input(*validators):
    """Decorator factory that validates arguments before the wrapped call.

    ``validators[i]`` checks the argument for the i-th positional parameter,
    whether it was passed positionally or by keyword. A falsy entry (e.g.
    ``None``) skips that position. A validator rejects a value by raising, or
    by returning ``False``, so plain predicates such as ``lambda x: x < 100``
    work (their result used to be silently ignored).

    Every explicitly passed argument whose parameter has a class annotation
    is also type-checked (see ``_check_annotation``). Type checks run first,
    so validators only see well-typed values. Defaults are not checked, so
    ``x: int = None`` is allowed.

    Raises:
        TypeError: wrong arity, or an argument does not match its annotation.
        ValueError: a predicate validator returned ``False``.
    """

    def decorator(func):
        # Introspect once at decoration time, not on every call.
        sig = inspect.signature(func)
        params = sig.parameters
        positional = [
            p.name
            for p in params.values()
            if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
        ]

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            bound = sig.bind(*args, **kwargs)  # TypeError on bad arity
            for name, value in bound.arguments.items():
                _check_annotation(params[name], value)
            for i, validator in enumerate(validators):
                if not validator:
                    continue
                if i < len(args):
                    value = args[i]
                elif i < len(positional) and positional[i] in kwargs:
                    value = kwargs[positional[i]]
                else:
                    continue  # argument omitted; its default is not validated
                if validator(value) is False:
                    label = positional[i] if i < len(positional) else f"#{i}"
                    raise ValueError(
                        f"Argument '{label}' failed validation: {value!r}"
                    )
            return func(*args, **kwargs)

        return wrapper

    return decorator


# Custom validator: signals failure by raising.
def positive_number(value):
    if value <= 0:
        raise ValueError("Value must be positive")


# Usage
@validate_input(positive_number, lambda x: x < 100)
def calculate_area(length: float, width: float) -> float:
    return length * width


print(calculate_area(5, 10))  # 50
# calculate_area(-5, 10)   # raises ValueError (positive_number)
# calculate_area(5, 150)   # raises ValueError (predicate returned False)
# calculate_area(5, "10")  # raises TypeError (annotation check)

# 4. Advanced mixins (混入类进阶)
4.1 动态混入
def dynamic_mixin(mixin_class):
    """Class decorator returning a new subclass of ``(mixin_class, cls)``.

    The mixin comes first in the MRO, so its methods take precedence over
    those of the decorated class. The new class keeps the decorated class's
    name, qualified name, module and docstring. If ``cls`` already derives
    from ``mixin_class``, it is returned unchanged: adding the mixin again
    would make the MRO inconsistent and raise TypeError.
    """

    def decorator(cls):
        if issubclass(cls, mixin_class):
            return cls
        namespace = {
            "__module__": cls.__module__,
            "__qualname__": cls.__qualname__,
            "__doc__": cls.__doc__,
        }
        return type(cls.__name__, (mixin_class, cls), namespace)

    return decorator


# Mixin class
class LoggingMixin:
    def log(self, message):
        print(f"[{self.__class__.__name__}] {message}")


# Original class
class DataProcessor:
    def process(self, data):
        return data.upper()


# Apply the mixin dynamically
@dynamic_mixin(LoggingMixin)
class EnhancedProcessor(DataProcessor):
    def process(self, data):
        self.log(f"Processing: {data}")
        return super().process(data)


processor = EnhancedProcessor()
result = processor.process("hello")  # [EnhancedProcessor] Processing: hello
print(result)  # HELLO

# 4.2 Interface-implementing mixins (接口实现混入)
import copy
import json


class SerializableMixin:
    """Mixin adding JSON (de)serialization based on the instance ``__dict__``.

    ``from_json`` calls ``cls(**data)``, so it assumes the constructor's
    parameter names match the serialized attribute names.
    """

    def to_json(self):
        return json.dumps(self.__dict__)

    @classmethod
    def from_json(cls, json_str):
        # json is now imported at module level. It used to be imported only
        # inside to_json, so this method raised NameError.
        data = json.loads(json_str)
        return cls(**data)


class CloneableMixin:
    """Mixin adding deep copying."""

    def clone(self):
        return copy.deepcopy(self)


class Entity(SerializableMixin, CloneableMixin):
    def __init__(self, id, name):
        self.id = id
        self.name = name


entity = Entity(1, "Test")
json_str = entity.to_json()
clone = entity.clone()
clone.id = 2
print(json_str)  # {"id": 1, "name": "Test"}
print(clone.id)  # 2

# 5. Advanced coroutines (协程进阶)
5.1 高级协程模式
import asyncio
from contextlib import asynccontextmanager
@asynccontextmanager
async def database_pool():
    """Async context manager yielding a simulated pool of 5 connections.

    The pool is a plain list of connection names and is cleared on exit.
    """
    print("Creating connection pool")
    pool = [f"Connection-{i}" for i in range(5)]
    try:
        yield pool
    finally:
        print("Closing all connections")
        pool.clear()


async def execute_query(pool, query, delay=1):
    """Borrow a connection from *pool*, run *query* (simulated), and return it.

    *delay* is the simulated execution time in seconds. The connection is put
    back even if the query fails or the task is cancelled mid-query; it used
    to leak in that case.

    Raises:
        RuntimeError: *pool* has no free connection. Callers that issue more
            concurrent queries than there are connections must bound their
            concurrency (e.g. with ``asyncio.Semaphore``).
    """
    if not pool:
        raise RuntimeError("No available connections")
    conn = pool.pop()
    try:
        print(f"Using {conn} to execute: {query}")
        await asyncio.sleep(delay)  # stand-in for real query execution
        return f"Result of {query}"
    finally:
        pool.append(conn)
async def main():
    async with database_pool() as pool:
        # Ten queries share five connections. Without a bound, queries 6-10
        # found the pool empty, and the whole gather() failed with
        # RuntimeError. The semaphore makes extra queries wait for a free one.
        limit = asyncio.Semaphore(len(pool))

        async def run_query(query):
            async with limit:
                return await execute_query(pool, query)

        tasks = [run_query(f"SELECT * FROM table_{i}") for i in range(10)]
        results = await asyncio.gather(*tasks)
        for result in results:
            print(result)


asyncio.run(main())

# 5.2 Combining coroutines with thread pools (协程与线程池结合)
import asyncio
from concurrent.futures import ThreadPoolExecutor
import time


def blocking_io(duration=2):
    """Simulate a blocking I/O call that takes *duration* seconds."""
    print("Blocking IO started")
    time.sleep(duration)
    return "IO result"


def cpu_bound(n=10**6):
    """Simulate CPU-heavy work: return the sum of squares of ``range(n)``."""
    print("CPU bound started")
    return sum(i * i for i in range(n))
async def main():
    loop = asyncio.get_running_loop()

    # Run a blocking I/O call in the loop's default executor.
    io_result = await loop.run_in_executor(None, blocking_io)
    print(io_result)

    # NOTE: threads do not speed up pure-Python CPU work, because the GIL
    # serializes it. A ProcessPoolExecutor would, but on spawn-based platforms
    # it needs an `if __name__ == "__main__"` guard this script does not have.
    with ThreadPoolExecutor() as pool:
        cpu_result = await loop.run_in_executor(pool, cpu_bound)
        print(cpu_result)

        # Run several jobs concurrently. This must stay inside the `with`:
        # once it exits, the executor is shut down, and submitting to it raises
        # RuntimeError (the original code submitted after the block).
        results = await asyncio.gather(
            loop.run_in_executor(None, blocking_io),
            loop.run_in_executor(None, blocking_io),
            loop.run_in_executor(pool, cpu_bound),
        )
    print(results)


asyncio.run(main())

# 5.3 Coroutine-based state machine (协程状态机)
class AsyncStateMachine:
    """Coroutine-driven state machine: INIT -> RUNNING <-> PAUSED -> STOPPED.

    Events sent with send() are consumed in order by the run() task. While
    RUNNING, exactly one worker task prints "Working..." every *interval*
    seconds. The machine must be constructed inside a running event loop,
    because ``__init__`` calls ``asyncio.create_task``.
    """

    # (state, event) -> next state; pairs not listed are ignored.
    # STOP is handled separately and is accepted in every state.
    _TRANSITIONS = {
        ("INIT", "START"): "RUNNING",
        ("RUNNING", "PAUSE"): "PAUSED",
        ("PAUSED", "RESUME"): "RUNNING",
    }

    def __init__(self, interval=1):
        self.state = "INIT"
        self.interval = interval
        self.queue = asyncio.Queue()
        # The worker task is stored rather than fired and forgotten: the event
        # loop keeps only weak references to tasks.
        self._worker_task = None
        self.task = asyncio.create_task(self.run())

    async def run(self):
        """Consume queued events until STOP, applying state transitions."""
        try:
            while True:
                event = await self.queue.get()
                print(f"Processing {event} in state {self.state}")
                if event == "STOP":
                    # Accepted from any state. Previously STOP was ignored in
                    # INIT, so calling stop() before START awaited forever.
                    self.state = "STOPPED"
                    break
                next_state = self._TRANSITIONS.get((self.state, event))
                if next_state is None:
                    continue
                self.state = next_state
                if next_state == "RUNNING":
                    self._start_worker()
                else:
                    # Cancel now rather than letting the worker notice on its
                    # next tick; otherwise a quick PAUSE/RESUME could leave two
                    # workers running.
                    await self._stop_worker()
        finally:
            await self._stop_worker()

    def _start_worker(self):
        """Start the worker task unless one is already running."""
        if self._worker_task is None or self._worker_task.done():
            self._worker_task = asyncio.create_task(self.worker())

    async def _stop_worker(self):
        """Cancel the worker task (if any) and wait for it to finish."""
        task, self._worker_task = self._worker_task, None
        if task is not None:
            task.cancel()
            # return_exceptions=True absorbs the worker's CancelledError.
            await asyncio.gather(task, return_exceptions=True)

    async def worker(self):
        while self.state == "RUNNING":
            print("Working...")
            await asyncio.sleep(self.interval)

    def send(self, event):
        self.queue.put_nowait(event)

    async def stop(self):
        """Request STOP and wait for the run() task to finish."""
        self.send("STOP")
        await self.task
# Usage
async def main():
    sm = AsyncStateMachine()
    sm.send("START")
    await asyncio.sleep(2)
    sm.send("PAUSE")
    await asyncio.sleep(1)
    sm.send("RESUME")
    await asyncio.sleep(1)
    # stop() sends STOP itself; an extra send("STOP") beforehand was redundant.
    await sm.stop()


asyncio.run(main())

# Best practices (最佳实践)
掌握这些高级特性可以显著提升代码质量和开发效率，
但务必在适当的场景中使用，避免引入不必要的复杂性。
评论区