在DRF(Django REST framework)中,`source` 参数用于在序列化器字段和模型字段(或模型属性)之间建立映射关系。本文将从用途和示例出发,帮助你理解其作用。
它的主要用途如下:
- 字段重命名
- 访问嵌套模型字段
- 处理多对多关系
1. 字段重命名
假如我们有如下model:
from django.db import models
class Order(models.Model):
    """Minimal order model used to demonstrate renaming a field via `source`."""

    name = models.CharField(max_length=255)

    def __str__(self):
        # Human-readable representation is just the order's name.
        return self.name
当前端需要使用 `order_name` 作为字段名来替代 `name` 时,我们可以创建如下的序列化器:
from rest_framework import serializers
from api.models import Order
class OrderSerializer(serializers.ModelSerializer):
    """Expose ``Order.name`` to API clients under the key ``order_name``."""

    # `source` binds this serializer field to the model attribute `name`:
    # incoming `order_name` ends up as validated_data['name'], and outgoing
    # representations read the instance's `name`.
    order_name = serializers.CharField(source='name')

    class Meta:
        model = Order
        fields = ('id', 'order_name')
在上面的示例中,我们使用 `source` 参数建立了序列化器字段 `order_name` 与模型字段 `name` 之间的映射关系:反序列化时,`order_name` 的值会被填充到 `validated_data` 的 `name` 键中;序列化时,模型实例的 `name` 属性值会被用于填充 `order_name`。
在下面的测试中,观察 `validated_data` 和序列化的结果,你可以清晰地看到这一点:
from rest_framework.test import APITestCase

from api.models import Order
from .serializers import OrderSerializer


class TestSource(APITestCase):
    def test_source(self):
        """`order_name` maps onto the model's `name` field in both directions.

        Written against the section-1 `Order` model, which only has `name`,
        so no `notes` or `created_by` are sent or saved here.
        """
        data = {
            'order_name': "测试订单",
        }
        s = OrderSerializer(data=data)
        s.is_valid(raise_exception=True)
        print(s.validated_data)
        # {'name': '测试订单'}
        # `name` is populated from `order_name`
        self.assertEqual(s.validated_data, {'name': '测试订单'})

        # Serialize the instance returned by save() rather than refetching a
        # hard-coded primary key.
        order = s.save()
        data = OrderSerializer(order).data
        print(data)
        # {'id': 1, 'order_name': '测试订单'}
        # `order_name` is populated from `name`
        self.assertEqual(data, {'id': order.pk, 'order_name': '测试订单'})
2. 访问嵌套(关联)模型字段
现在为刚才的 `Order` 模型添加一个 `notes` 字段和一个指向 `User` 的外键:
from django.db import models
from django.contrib.auth.models import User
class Order(models.Model):
    """Order model extended with notes and an optional creator (FK to User)."""

    name = models.CharField(max_length=255)
    notes = models.CharField(max_length=255)
    # Nullable FK: an order may exist without a creator. Deleting the user
    # also deletes their orders (CASCADE).
    created_by = models.ForeignKey(User, on_delete=models.CASCADE, null=True, blank=True)

    def __str__(self):
        # Human-readable representation is just the order's name.
        return self.name
更改序列化器如下:
from rest_framework import serializers
from api.models import Order
class OrderSerializer(serializers.ModelSerializer):
    """Order serializer showing `source` for renaming and for related lookups."""

    # Read-only; renders the related User with str().
    created_by = serializers.StringRelatedField(read_only=True)
    # Renamed field: maps to and from the model's `name` attribute.
    order_name = serializers.CharField(source='name')
    # A dotted source follows the FK: instance.created_by.username.
    # `created_by` is nullable, so the lookup can hit None. Without
    # allow_null=True, DRF skips a non-required field whose lookup fails,
    # and the key would silently vanish from the output.
    create_user_name = serializers.CharField(
        source='created_by.username',
        read_only=True,
        allow_null=True,
    )

    class Meta:
        model = Order
        fields = ['id', 'notes', 'created_by', 'create_user_name', 'order_name']
测试一下:
from django.contrib.auth.models import User
from rest_framework.test import APITestCase

from api.models import Order
from .serializers import OrderSerializer


class TestSource(APITestCase):
    def test_source(self):
        """Renamed field plus dotted `source` through the `created_by` FK."""
        user = User.objects.create_user(username='用户1', password='test')
        data = {
            'order_name': "测试订单",
            'notes': "用于测试",
        }
        s = OrderSerializer(data=data)
        s.is_valid(raise_exception=True)
        print(s.validated_data)
        # {'notes': '用于测试', 'name': '测试订单'}
        self.assertEqual(s.validated_data, {'notes': '用于测试', 'name': '测试订单'})

        # created_by is read-only on the serializer, so it is supplied to save().
        order = s.save(created_by=user)
        data = OrderSerializer(order).data
        print(data)
        # {'id': 1, 'notes': '用于测试', 'created_by': '用户1', 'create_user_name': '用户1', 'order_name': '测试订单'}
        self.assertEqual(data, {
            'id': order.pk,
            'notes': '用于测试',
            'created_by': '用户1',
            'create_user_name': '用户1',
            'order_name': '测试订单',
        })
3. 处理多对多关系
模型如下:
from django.db import models
class Blog(models.Model):
    """A blog entry with a short `title` and a free-length `detail` text field."""

    title = models.CharField(max_length=255)
    detail = models.TextField()
class BlogGroup(models.Model):
    """A named group of blogs; this side declares the many-to-many relation."""

    name = models.CharField(max_length=255)
    # related_name='groups' gives Blog the reverse accessor `blog.groups`,
    # which the blog serializers read and write through the name `groups`.
    blogs = models.ManyToManyField(Blog, related_name='groups')

    def __str__(self):
        # Human-readable representation is just the group's name.
        return self.name
假如有这样的需求:我们希望在创建 blog 的时候可以传入一组 `group_ids`。模型中 `Blog` 和 `BlogGroup` 是多对多关系,如果通过重写 `ModelSerializer` 的 `create` 方法来实现会显得很麻烦,此时借助 `source` 我们可以很轻松地完成这项工作:
from rest_framework import serializers
from .models import Blog, BlogGroup
class BlogSer(serializers.ModelSerializer):
    """Blog serializer that reads groups by name and writes them by primary key."""

    # Output side: each related BlogGroup rendered with str().
    groups = serializers.StringRelatedField(many=True, read_only=True)
    # Input side: a list of BlogGroup primary keys, resolved against the
    # queryset and stored under validated_data['groups'] thanks to `source`.
    group_ids = serializers.PrimaryKeyRelatedField(
        many=True,
        source='groups',
        write_only=True,
        queryset=BlogGroup.objects.all(),
    )

    class Meta:
        model = Blog
        fields = ('id', 'title', 'detail', 'groups', 'group_ids')
测试一下:
from rest_framework.test import APITestCase

from .serializers import BlogSer
from .models import Blog, BlogGroup


class TestSource(APITestCase):
    def test_source_get(self):
        """`group_ids` (write) is mapped to the `groups` relation via `source`."""
        g1 = BlogGroup.objects.create(name='测试组1')
        g2 = BlogGroup.objects.create(name='测试组2')
        data = {
            'title': 'blog1',
            'detail': 'my blog detail.',
            # Use the real primary keys: an empty list would attach no groups.
            'group_ids': [g1.pk, g2.pk],
        }
        s = BlogSer(data=data)
        s.is_valid(raise_exception=True)
        print(s.validated_data)
        # {'title': 'blog1', 'detail': 'my blog detail.', 'groups': [<BlogGroup: 测试组1>, <BlogGroup: 测试组2>]}
        self.assertEqual(list(s.validated_data['groups']), [g1, g2])

        blog = s.save()
        data = BlogSer(blog).data
        print(data)
        # {'id': 1, 'title': 'blog1', 'detail': 'my blog detail.', 'groups': ['测试组1', '测试组2']}
        self.assertEqual(data['id'], blog.pk)
        self.assertEqual(data['title'], 'blog1')
        self.assertEqual(data['detail'], 'my blog detail.')
        # No ordering is defined on BlogGroup, so compare groups order-insensitively.
        self.assertCountEqual(data['groups'], ['测试组1', '测试组2'])
上述做法的缺点也很明显:`PrimaryKeyRelatedField` 在校验时会对传入的每个 id 单独执行一次查询,如果 `group_ids` 较多,就会产生大量的查询:
>>> data = {
'title': 'blog1',
'detail': 'my blog detail.',
'group_ids': [1, 2, 3, 4, 5, 6, 7, 8, 9]
}
>>> s = BlogSer(data=data)
# (0.000) SELECT "api_bloggroup"."id", "api_bloggroup"."name" FROM "api_bloggroup" LIMIT 21; args=(); alias=default
>>> s.is_valid(raise_exception=True)
# (0.000) SELECT "api_bloggroup"."id", "api_bloggroup"."name" FROM "api_bloggroup" WHERE "api_bloggroup"."id" = 1 LIMIT 21; args=(1,); alias=default
# (0.000) SELECT "api_bloggroup"."id", "api_bloggroup"."name" FROM "api_bloggroup" WHERE "api_bloggroup"."id" = 2 LIMIT 21; args=(2,); alias=default
# (0.000) SELECT "api_bloggroup"."id", "api_bloggroup"."name" FROM "api_bloggroup" WHERE "api_bloggroup"."id" = 3 LIMIT 21; args=(3,); alias=default
# (0.000) SELECT "api_bloggroup"."id", "api_bloggroup"."name" FROM "api_bloggroup" WHERE "api_bloggroup"."id" = 4 LIMIT 21; args=(4,); alias=default
# (0.000) SELECT "api_bloggroup"."id", "api_bloggroup"."name" FROM "api_bloggroup" WHERE "api_bloggroup"."id" = 5 LIMIT 21; args=(5,); alias=default
# (0.000) SELECT "api_bloggroup"."id", "api_bloggroup"."name" FROM "api_bloggroup" WHERE "api_bloggroup"."id" = 6 LIMIT 21; args=(6,); alias=default
# (0.000) SELECT "api_bloggroup"."id", "api_bloggroup"."name" FROM "api_bloggroup" WHERE "api_bloggroup"."id" = 7 LIMIT 21; args=(7,); alias=default
# (0.000) SELECT "api_bloggroup"."id", "api_bloggroup"."name" FROM "api_bloggroup" WHERE "api_bloggroup"."id" = 8 LIMIT 21; args=(8,); alias=default
# (0.000) SELECT "api_bloggroup"."id", "api_bloggroup"."name" FROM "api_bloggroup" WHERE "api_bloggroup"."id" = 9 LIMIT 21; args=(9,); alias=default
True
>>> (0.000) SELECT "api_bloggroup"."id", "api_bloggroup"."name" FROM "api_bloggroup" LIMIT 21; args=(); alias=default
如果你比较在意性能,或者明确知道此接口可能经常批量地处理大量数据,还是自己重写 `create`、`update` 方法为好(下面的实现只用一次 `IN` 查询来校验所有 id):
from django.db import transaction
from rest_framework import serializers

from .models import Blog, BlogGroup


class BlogSer(serializers.ModelSerializer):
    """Blog serializer that validates `group_ids` with one query instead of one per id.

    `group_ids` is a plain list of integers (write-only). It is checked in
    bulk by `validate_group_ids` and applied to the `groups` relation in
    `create` / `update`.
    """

    groups = serializers.StringRelatedField(read_only=True, many=True)
    group_ids = serializers.ListField(child=serializers.IntegerField(), write_only=True)

    class Meta:
        model = Blog
        fields = ['id', 'title', 'detail', 'groups', 'group_ids']

    def validate_group_ids(self, value):
        """Reject any id that has no matching BlogGroup; return the list unchanged.

        Raises:
            serializers.ValidationError: listing the unknown ids.
        """
        # A single IN query covers every submitted id.
        existing_ids = set(BlogGroup.objects.filter(id__in=value).values_list('id', flat=True))
        invalid_ids = set(value) - existing_ids
        if invalid_ids:
            # sorted() keeps the message stable regardless of set iteration order.
            raise serializers.ValidationError(
                f"Invalid group ids: {', '.join(map(str, sorted(invalid_ids)))}"
            )
        return value

    @transaction.atomic
    def create(self, validated_data):
        """Create the blog and attach its groups atomically (no orphan blog on failure)."""
        group_ids = validated_data.pop('group_ids')
        blog = Blog.objects.create(**validated_data)
        # The ids were already verified in validate_group_ids, and the M2M
        # manager's set() accepts primary keys directly: no extra SELECT.
        blog.groups.set(group_ids)
        return blog

    @transaction.atomic
    def update(self, instance, validated_data):
        """Update scalar fields, then replace groups only when `group_ids` was sent.

        Omitting `group_ids` (e.g. in a partial update) leaves the groups untouched.
        """
        group_ids = validated_data.pop('group_ids', None)
        for attr, value in validated_data.items():
            setattr(instance, attr, value)
        # Persist the row first, then the relation (same order as DRF's own update()).
        instance.save()
        if group_ids is not None:
            instance.groups.set(group_ids)
        return instance