Author: yukkizhang, Test Engineer, Tencent CSIG
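This article starts from the commonly used Python unit testing frameworks and gives a short introduction and summary of each. It then covers mock frameworks and ways to generate test reports, with concrete code examples throughout, and ends with some common questions.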
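If you don't want to install third-party libraries, or aren't allowed to, unittest is the best and indeed the only choice. Otherwise, pytest is clearly the best choice. Many Python open-source projects (such as the well-known requests) use pytest as their unit testing framework, and even nose2's official documentation recommends pytest. nose has entered maintenance mode, and its successor is nose2. Compared with nose2, pytest has a stronger ecosystem and a more active community.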
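Overall, unittest has these traits:
- Test cases are verbose to write.
- It cannot run test cases written for other frameworks.
- It has few plugins, but it is easy to extend.

pytest is quicker and more convenient to use:
- Its test case format is simple.
- It can also run unittest-style test cases, so compatibility is good.
- It has a rich plugin ecosystem.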
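The four core concepts in unittest are **test fixture, test case, test suite and test runner**:
- A test fixture is the setup and teardown work around tests.
- A test case is a single unit of testing.
- A test suite groups test cases together.
- A test runner executes them and reports the results.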
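When writing unit tests, we create a test class that inherits from unittest.TestCase.
Methods whose names start with test are test methods. Methods that don't start with test are not treated as tests and are not run.
Each kind of test needs its own test_xxx() method.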
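A quick way to see the naming rule in action is to ask unittest's default loader which methods it would collect. This is an illustrative sketch; the class and method names are made up:

```python
import unittest


class TestNamingRule(unittest.TestCase):
    # Collected: the name starts with the loader's prefix "test"
    def test_addition(self):
        self.assertEqual(1 + 1, 2)

    # Not collected: an ordinary helper without the "test" prefix
    def build_operands(self):
        return (1, 1)


# TestLoader.testMethodPrefix defaults to "test"
loader = unittest.TestLoader()
collected = loader.getTestCaseNames(TestNamingRule)
print(loader.testMethodPrefix, collected)  # test ['test_addition']
```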
$ tree .
.
├── README.md
├── requirements.txt
└── src
    ├── demo
    │   └── calculator.py
    └── tests
        └── demo
            ├── __init__.py
            ├── test_calculator_unittest.py
            └── test_calculator_unittest_with_fixture.py
class Calculator:
    def add(self, a, b):
        return a + b

    def sub(self, a, b):
        return a - b

    def mul(self, a, b):
        return a * b

    def div(self, a, b):
        return a / b
import unittest

from src.demo.calculator import Calculator


class TestCalculator(unittest.TestCase):
    def test_add(self):
        c = Calculator()
        result = c.add(3, 5)
        self.assertEqual(result, 8)

    def test_sub(self):
        c = Calculator()
        result = c.sub(10, 5)
        self.assertEqual(result, 5)

    def test_mul(self):
        c = Calculator()
        result = c.mul(5, 7)
        self.assertEqual(result, 35)

    def test_div(self):
        c = Calculator()
        result = c.div(10, 5)
        self.assertEqual(result, 2)


if __name__ == '__main__':
    unittest.main()
Ran 4 tests in 0.002s

OK
Applying the four unittest concepts, the simple test case above can be rewritten as:
import unittest

from src.demo.calculator import Calculator


class TestCalculatorWithFixture(unittest.TestCase):
    # Runs before each test method
    def setUp(self):
        print("test start")

    # Runs after each test method
    def tearDown(self):
        print("test end")

    def test_add(self):
        c = Calculator()
        result = c.add(3, 5)
        self.assertEqual(result, 8)

    def test_sub(self):
        c = Calculator()
        result = c.sub(10, 5)
        self.assertEqual(result, 5)

    def test_mul(self):
        c = Calculator()
        result = c.mul(5, 7)
        self.assertEqual(result, 35)

    def test_div(self):
        c = Calculator()
        result = c.div(10, 5)
        self.assertEqual(result, 2)


if __name__ == '__main__':
    # Create the test suite
    suit = unittest.TestSuite()
    suit.addTest(TestCalculatorWithFixture("test_add"))
    suit.addTest(TestCalculatorWithFixture("test_sub"))
    suit.addTest(TestCalculatorWithFixture("test_mul"))
    suit.addTest(TestCalculatorWithFixture("test_div"))
    # Create the test runner (unittest's text runner is TextTestRunner)
    runner = unittest.TextTestRunner()
    runner.run(suit)
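Adding each test by name works. For a whole class, though, it is usually shorter to let TestLoader build the suite. A minimal sketch, with two assumptions:
- A stand-in Calculator is defined inline, so it runs without the src package.
- Runner output goes to an in-memory stream.

```python
import io
import unittest


class Calculator:
    # Stand-in for src.demo.calculator.Calculator (same behaviour)
    def add(self, a, b):
        return a + b

    def div(self, a, b):
        return a / b


class TestCalculator(unittest.TestCase):
    def setUp(self):
        # Fixture: a fresh Calculator for every test method
        self.calc = Calculator()

    def test_add(self):
        self.assertEqual(self.calc.add(3, 5), 8)

    def test_div(self):
        self.assertEqual(self.calc.div(10, 5), 2)


# Build the suite from every test_* method instead of calling addTest() per name
suite = unittest.TestLoader().loadTestsFromTestCase(TestCalculator)

# verbosity=2 prints one line per test into the given stream
stream = io.StringIO()
result = unittest.TextTestRunner(stream=stream, verbosity=2).run(suite)
print(result.testsRun, result.wasSuccessful())  # 2 True
```

Keeping the object under test in setUp also removes the repeated c = Calculator() lines.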
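The standard library's unittest does not support parameterized tests on its own. The third-party libraries parameterized and ddt add this support.
parameterized needs only one decorator, @parameterized.expand, while ddt needs three: @ddt, @data and @unpack. Both give each generated test its own name, and ddt's names also include the actual parameter values.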
import unittest

from parameterized import parameterized, param

from src.demo.calculator import Calculator


class TestCalculator(unittest.TestCase):
    @parameterized.expand([
        param(3, 5, 8),
        param(1, 2, 3),
        param(2, 2, 4)
    ])
    def test_add(self, num1, num2, total):
        c = Calculator()
        result = c.add(num1, num2)
        self.assertEqual(result, total)


if __name__ == '__main__':
    unittest.main()
Output:
test_add_0 (__main__.TestCalculator) ... ok
test_add_1 (__main__.TestCalculator) ... ok
test_add_2 (__main__.TestCalculator) ... ok

----------------------------------------------------------------------
Ran 3 tests in 0.000s

OK
import unittest

from ddt import data, unpack, ddt

from src.demo.calculator import Calculator


@ddt
class TestCalculator(unittest.TestCase):
    @data((3, 5, 8), (1, 2, 3), (2, 2, 4))
    @unpack
    def test_add(self, num1, num2, total):
        c = Calculator()
        result = c.add(num1, num2)
        self.assertEqual(result, total)


if __name__ == '__main__':
    unittest.main()
Output:
test_add_1__3__5__8_ (__main__.TestCalculator) ... ok
test_add_2__1__2__3_ (__main__.TestCalculator) ... ok
test_add_3__2__2__4_ (__main__.TestCalculator) ... ok

----------------------------------------------------------------------
Ran 3 tests in 0.000s

OK
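If installing parameterized or ddt isn't an option, the standard library's subTest (Python 3.4+) is a partial substitute. Each failing case is reported separately, but all cases share one test method and count as a single test. A sketch that uses plain + instead of Calculator to stay self-contained:

```python
import io
import unittest


class TestAddWithSubTest(unittest.TestCase):
    cases = [(3, 5, 8), (1, 2, 3), (2, 2, 4)]

    def test_add(self):
        for num1, num2, total in self.cases:
            # Each subTest is reported on its own if it fails,
            # and the loop continues with the remaining cases
            with self.subTest(num1=num1, num2=num2, total=total):
                self.assertEqual(num1 + num2, total)


suite = unittest.TestLoader().loadTestsFromTestCase(TestAddWithSubTest)
result = unittest.TextTestRunner(stream=io.StringIO()).run(suite)
print(result.testsRun, result.wasSuccessful())  # 1 True
```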
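unittest provides a rich set of assertion methods. Commonly used ones include
assertEqual, assertNotEqual, assertTrue, assertFalse, assertIn, assertNotIn and so on.
For the full list, see the assert* methods defined on unittest.TestCase in the source code.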
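The sketch below exercises the assertions listed above, plus a few other standard TestCase methods (assertIs, assertAlmostEqual, assertRaises), in one illustrative test:

```python
import io
import unittest


class TestCommonAssertions(unittest.TestCase):
    def test_examples(self):
        self.assertEqual(3 + 5, 8)
        self.assertNotEqual(10 - 5, 4)
        self.assertTrue(isinstance(2, int))
        self.assertFalse([])  # an empty list is falsy
        self.assertIn(3, [1, 2, 3])
        self.assertNotIn(4, [1, 2, 3])
        self.assertIs(None, None)
        # Floats: compares after rounding the difference to 7 places by default
        self.assertAlmostEqual(0.1 + 0.2, 0.3)
        # Exceptions: the context-manager form wraps the failing call
        with self.assertRaises(ZeroDivisionError):
            10 / 0


result = unittest.TextTestRunner(stream=io.StringIO()).run(
    unittest.TestLoader().loadTestsFromTestCase(TestCommonAssertions)
)
print(result.wasSuccessful())  # True
```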
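nose has entered maintenance mode: its GitHub repository (nose) shows that the most recent commit dates back to May 4, 2016.
nose's successor is nose2, but note that nose2 does not support every nose feature; the nose2 documentation describes the differences. nose2's main goal is to extend Python's standard unittest library, so it positions itself as "unittest with plugins". Its built-in and third-party plugins make unit testing more complete. Examples include test loaders, coverage report generation and parallel test execution.
nose2's community is less active than pytest's, so pytest is the recommended choice for a more advanced framework. nose2 is therefore not covered in detail below.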
For the calculator code, see the unittest section.
import nose2

from src.demo.calculator import Calculator


def test_add():
    c = Calculator()
    result = c.add(3, 5)
    assert result == 8


def test_sub():
    c = Calculator()
    result = c.sub(10, 5)
    assert result == 5


def test_mul():
    c = Calculator()
    result = c.mul(5, 7)
    assert result == 35


def test_div():
    c = Calculator()
    result = c.div(10, 5)
    assert result == 2


if __name__ == '__main__':
    nose2.main()
....
----------------------------------------------------------------------
Ran 4 tests in 0.000s

OK
import nose2
from nose2.tools import params

from src.demo.calculator import Calculator

test_data = [
    {"nums": (3, 5), "total": 8},
    {"nums": (1, 2), "total": 3},
    {"nums": (2, 2), "total": 4}
]


@params(*test_data)
def test_add(data):
    c = Calculator()
    result = c.add(*data['nums'])
    assert result == data['total']


if __name__ == '__main__':
    nose2.main()
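The following command shows which test cases pytest collects (py.test is the older name of the same command):
$ pytest --collect-only
For the calculator code, see the unittest section.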
import pytest

from src.demo.calculator import Calculator


class TestCalculator:
    def test_add(self):
        c = Calculator()
        result = c.add(3, 5)
        assert result == 8

    def test_sub(self):
        c = Calculator()
        result = c.sub(10, 5)
        assert result == 5

    def test_mul(self):
        c = Calculator()
        result = c.mul(5, 7)
        assert result == 35

    def test_div(self):
        c = Calculator()
        result = c.div(10, 5)
        assert result == 2


if __name__ == '__main__':
    pytest.main(['-s', 'test_calculator_pytest.py'])
============================= test session starts ==============================
platform darwin -- Python 3.8.3, pytest-6.2.2, py-1.10.0, pluggy-0.13.1
rootdir: python-ut/src/tests/demo
plugins: metadata-1.11.0, html-3.1.1
collected 4 items

test_calculator_pytest.py ....
============================== 4 passed in 0.01s ===============================
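A fixture can be attached in several ways. The example below shows two of them:
- passing the fixture as a test method argument (test_add, test_sub);
- applying it with @pytest.mark.usefixtures (test_mul, test_div).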
import pytest

from src.demo.calculator import Calculator


@pytest.fixture()
def set_up():
    print("[pytest with fixture] start")
    yield
    print("[pytest with fixture] end")


class TestCalculator:
    def test_add(self, set_up):
        c = Calculator()
        result = c.add(3, 5)
        assert result == 8

    def test_sub(self, set_up):
        c = Calculator()
        result = c.sub(10, 5)
        assert result == 5

    @pytest.mark.usefixtures("set_up")
    def test_mul(self):
        c = Calculator()
        result = c.mul(5, 7)
        assert result == 35

    @pytest.mark.usefixtures("set_up")
    def test_div(self):
        c = Calculator()
        result = c.div(10, 5)
        assert result == 2


if __name__ == '__main__':
    pytest.main(['-s', 'test_calculator_pytest_with_fixture.py'])
Output:
============================= test session starts ==============================
platform darwin -- Python 3.8.3, pytest-6.2.2, py-1.10.0, pluggy-0.13.1
rootdir: python-ut/src/tests/demo
plugins: metadata-1.11.0, html-3.1.1
collected 4 items

test_calculator_pytest_with_fixture.py [pytest with fixture] start
.[pytest with fixture] end
[pytest with fixture] start
.[pytest with fixture] end
[pytest with fixture] start
.[pytest with fixture] end
[pytest with fixture] start
.[pytest with fixture] end
============================== 4 passed in 0.01s ===============================
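The set_up fixture above only prints. A fixture can also build and return the object under test, and a test receives that value through a parameter with the fixture's name. A sketch, assuming an inline stand-in Calculator:

```python
import pytest


class Calculator:
    # Stand-in for src.demo.calculator.Calculator
    def add(self, a, b):
        return a + b

    def mul(self, a, b):
        return a * b


@pytest.fixture
def calculator():
    # Setup: create the object under test and hand it to the test
    calc = Calculator()
    yield calc
    # Teardown code would go here (nothing to release in this example)


class TestCalculator:
    def test_add(self, calculator):
        assert calculator.add(3, 5) == 8

    def test_mul(self, calculator):
        assert calculator.mul(5, 7) == 35
```

By default a fixture is created once per test. If building the object were expensive, @pytest.fixture(scope="class") or scope="module" would share one instance instead.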
Parametrization uses @pytest.mark.parametrize:
- A single parameter: @pytest.mark.parametrize("num1", [3, 5, 8])
- Several parameters: @pytest.mark.parametrize("num1, num2, total", [(3, 5, 8), (1, 2, 3), (2, 2, 4)])

When @pytest.mark.parametrize decorates a test class, the data set is passed to every test method in the class.
When several @pytest.mark.parametrize decorators are stacked, the number of test cases is the product of their sizes (N*M...).

import pytest

from src.demo.calculator import Calculator


class TestCalculator:
    @pytest.mark.parametrize("num1, num2, total", [(3, 5, 8), (1, 2, 3), (2, 2, 4)])
    def test_add(self, num1, num2, total):
        c = Calculator()
        result = c.add(num1, num2)
        assert result == total


if __name__ == '__main__':
    pytest.main(['test_calculator_pytest_with_parameterize.py'])
Output:
============================= test session starts ==============================
platform darwin -- Python 3.8.3, pytest-6.2.2, py-1.10.0, pluggy-0.13.1
rootdir: python-ut/src/tests/demo
plugins: metadata-1.11.0, html-3.1.1
collected 3 items

test_calculator_pytest_with_paramtrize.py ...
============================== 3 passed in 0.01s ===============================
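The class-level and stacked forms of parametrize can be sketched as follows (the test names are illustrative):
- TestArithmetic is parametrized at class level: 2 data sets × 2 methods, so pytest --collect-only should list 4 items.
- The stacked decorators on test_product_is_positive yield 3 × 2 = 6 items.

```python
import pytest


# Class-level parametrize: every test method in the class receives num1/num2
@pytest.mark.parametrize("num1, num2", [(3, 5), (10, 2)])
class TestArithmetic:
    def test_add_is_commutative(self, num1, num2):
        assert num1 + num2 == num2 + num1

    def test_mul_is_commutative(self, num1, num2):
        assert num1 * num2 == num2 * num1


# Stacked decorators: 3 values of x times 2 values of y -> 6 test cases
@pytest.mark.parametrize("x", [1, 2, 3])
@pytest.mark.parametrize("y", [10, 20])
def test_product_is_positive(x, y):
    assert x * y > 0
```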
pytest.param can attach marks to individual parameter sets:

import pytest

from src.demo.calculator import Calculator


class TestCalculator:
    @pytest.mark.parametrize("num1, num2, total", [
        pytest.param(5, 1, 4, marks=pytest.mark.passed),
        pytest.param(5, 2, 4, marks=pytest.mark.fail),
        (5, 4, 1)
    ])
    def test_sub(self, num1, num2, total):
        c = Calculator()
        result = c.sub(num1, num2)
        assert result == total


if __name__ == '__main__':
    pytest.main(['test_calculator_pytest_with_parameterize.py'])
Output:
============================= test session starts ==============================
platform darwin -- Python 3.8.3, pytest-6.2.2, py-1.10.0, pluggy-0.13.1
rootdir: python-ut/src/tests/demo
plugins: metadata-1.11.0, html-3.1.1
collected 3 items

test_calculator_pytest_with_paramtrize.py .F. [100%]

=================================== FAILURES ===================================
________________________ TestCalculator.test_sub[5-2-4] ________________________

self = <demo.test_calculator_pytest_with_paramtrize.TestCalculator object at 0x110813d00>
num1 = 5, num2 = 2, total = 4

    @pytest.mark.parametrize("num1, num2, total", [
        pytest.param(5, 1, 4, marks=pytest.mark.passed),
        pytest.param(5, 2, 4, marks=pytest.mark.fail),
        (5, 4, 1)
    ])
    def test_sub(self, num1, num2, total):
        c = Calculator()
        result = c.sub(num1, num2)
>       assert result == total
E       assert 3 == 4

test_calculator_pytest_with_paramtrize.py:21: AssertionError
=========================== short test summary info ============================
FAILED test_calculator_pytest_with_paramtrize.py::TestCalculator::test_sub[5-2-4]
=================== 1 failed, 2 passed, 2 warnings in 0.04s ====================
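In the example above, passed and fail are not built-in pytest marks. They only label the cases and do not change the outcome. Being unregistered, they are most likely the source of the 2 warnings in the summary. To declare that a case is expected to fail, the built-in pytest.mark.xfail can be used, and id= gives a case a readable name. A sketch with an inline stand-in Calculator:

```python
import pytest


class Calculator:
    # Stand-in for src.demo.calculator.Calculator
    def sub(self, a, b):
        return a - b


@pytest.mark.parametrize("num1, num2, total", [
    pytest.param(5, 1, 4, id="five-minus-one"),
    # Expected failure: reported as "xfailed" instead of "failed"
    pytest.param(5, 2, 4, marks=pytest.mark.xfail(reason="5 - 2 is 3, not 4"), id="wrong-total"),
    (5, 4, 1),
])
def test_sub(num1, num2, total):
    assert Calculator().sub(num1, num2) == total
```

If custom labels are really wanted, registering them under markers in pytest.ini (or pyproject.toml) silences the unknown-mark warning.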
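The unittest framework provides many assertion methods, such as assertEqual(), assertIn(), assertTrue() and assertIs(). pytest provides no special assertion methods and simply uses Python's assert statement.
assert can use the operators ==, !=, <, >, >= and <= to check equal, not equal, less than, greater than, greater than or equal, and less than or equal.
To assert membership or non-membership, use assert a in b and assert a not in b.
To assert truthiness, use assert condition and assert not condition.
To assert an exception, use pytest.raises to capture its information: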
import pytest


# Detailed exception assertion
def test_zero_division_long():
    with pytest.raises(ZeroDivisionError) as excinfo:
        1 / 0
    # Assert the exception type
    assert excinfo.type == ZeroDivisionError
    # Assert the exception value (message)
    assert "division by zero" in str(excinfo.value)
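Two related pytest helpers can shorten such checks:
- pytest.raises accepts match=, a regular expression searched in the exception's string form.
- pytest.approx compares floats with a tolerance.

The function names in this sketch are illustrative:

```python
import pytest


def test_zero_division_short():
    # match= is a regular expression searched in str(exception)
    with pytest.raises(ZeroDivisionError, match="division by zero"):
        1 / 0


def test_float_sum():
    # Plain == fails here because 0.1 + 0.2 != 0.3 in binary floating point
    assert 0.1 + 0.2 == pytest.approx(0.3)


def test_membership_and_truthiness():
    assert 3 in [1, 2, 3]
    assert "x" not in "abc"
    assert not []
```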
Re-running flaky tests requires the extra plugin pytest-rerunfailures:

import random

import pytest


@pytest.mark.flaky(reruns=5)
def test_example():
    assert random.choice([True, False, False])
Output:
collecting ... collected 1 item

11_reruns.py::test_example RERUN [100%]
11_reruns.py::test_example PASSED [100%]
========================= 1 passed, 1 rerun in 0.05s ==========================
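mock was originally a third-party Python library. In Python 3 it was merged into the unittest framework as unittest.mock.
On Python 3.3 or later, no separate installation is needed: just add from unittest import mock at the top of the file.
On Python 2, install it first with pip install mock, then use import mock.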
import unittest
from unittest import mock


def multiple(a, b):
    return a * b


class TestCalculator(unittest.TestCase):
    # Patch by the module's own name. The target is then correct whether the file
    # runs as a script (__main__) or is imported as a module by a test runner.
    @mock.patch(f'{__name__}.multiple')
    def test_function_multiple(self, mock_multiple):
        mock_return = 1
        mock_multiple.return_value = mock_return
        result = multiple(3, 5)
        self.assertEqual(result, mock_return)


if __name__ == '__main__':
    unittest.main()
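Besides the decorator form, mock.patch also works as a context manager, which limits the patch to a with block. The returned mock records its calls, so the interaction can be verified as well as the return value. In this sketch:
- time.time is patched only to keep the example deterministic.
- elapsed_since is a hypothetical function under test.

```python
import time
from unittest import mock


def elapsed_since(start):
    # Code under test: depends on the real clock via time.time()
    return time.time() - start


with mock.patch("time.time", return_value=100.0) as mock_time:
    result = elapsed_since(40.0)
    # Verify the dependency was called exactly once, with no arguments
    mock_time.assert_called_once_with()

# Outside the with block the real time.time is restored
print(result)  # 60.0
```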
The examples below show the plain style and the decorator style, plus the effect of the side_effect keyword argument.
import unittest
from unittest import mock

from src.demo.calculator import Calculator


class TestCalculator(unittest.TestCase):
    def test_add(self):
        c = Calculator()
        mock_return = 10
        c.add = mock.Mock(return_value=mock_return)
        result = c.add(3, 5)
        self.assertEqual(result, mock_return)

    def test_add_with_side_effect(self):
        c = Calculator()
        mock_return = 10
        # side_effect takes precedence over return_value: the mock calls the real add method
        c.add = mock.Mock(return_value=mock_return, side_effect=c.add)
        result = c.add(3, 5)
        self.assertEqual(result, 8)

    @mock.patch.object(Calculator, 'add')
    def test_add_with_annotation(self, mock_add):
        c = Calculator()
        mock_return = 10
        mock_add.return_value = mock_return
        result = c.add(3, 5)
        self.assertEqual(result, mock_return)


if __name__ == '__main__':
    unittest.main()
Assigning an iterable to side_effect makes consecutive calls return successive values:

import unittest
from unittest import mock

from src.demo.calculator import Calculator


class TestCalculator(unittest.TestCase):
    @mock.patch.object(Calculator, 'add')
    def test_add_with_different_return(self, mock_add):
        c = Calculator()
        mock_return = [10, 8]
        mock_add.side_effect = mock_return
        result1 = c.add(3, 5)
        result2 = c.add(3, 5)
        self.assertEqual(result1, mock_return[0])
        self.assertEqual(result2, mock_return[1])


if __name__ == '__main__':
    unittest.main()
import unittest
from unittest import mock


# Dependency that will be mocked
def multiple(a, b):
    return a * b


# Function under test, which calls multiple()
def is_error(a, b):
    try:
        return multiple(a, b)
    except Exception:
        return -1


class TestCalculator(unittest.TestCase):
    # Patch by the module's own name (works as a script and as an imported module)
    @mock.patch(f'{__name__}.multiple')
    def test_function_multiple_exception(self, mock_multiple):
        # side_effect set to an exception class makes the mock raise it
        mock_multiple.side_effect = Exception
        result = is_error(3, 5)
        self.assertEqual(result, -1)


if __name__ == '__main__':
    unittest.main()
import unittest
from unittest import mock

from src.demo.calculator import Calculator


def multiple(a, b):
    return a * b


class TestCalculator(unittest.TestCase):
    # Stacked patch decorators are applied bottom-up, so the mock for the
    # bottom decorator (multiple) is passed as the first argument
    @mock.patch.object(Calculator, 'add')
    @mock.patch(f'{__name__}.multiple')
    def test_both(self, mock_multiple, mock_add):
        c = Calculator()
        mock_add.return_value = 1
        mock_multiple.return_value = 2
        self.assertEqual(c.add(3, 5), 1)
        self.assertEqual(multiple(3, 5), 2)


if __name__ == '__main__':
    unittest.main()
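If the project already uses pytest, the pytest-mock plugin is the recommended way to mock. It provides a fixture named mocker whose patches apply only within the current test function or method and are undone automatically, so you don't have to manage them yourself.
mocker.patch has the same API as mock.patch and accepts the same arguments.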
import pytest

from src.demo.calculator import Calculator


class TestCalculator:
    def test_add(self, mocker):
        c = Calculator()
        mock_return = 10
        mocker.patch.object(c, 'add', return_value=mock_return)
        result = c.add(3, 5)
        assert result == mock_return


if __name__ == '__main__':
    pytest.main(['-s', 'test_calculator_pytest_mock.py'])
class ForTest:
    field = 'origin'

    def method(self):
        pass


def test_for_test(mocker):
    test = ForTest()
    # Patch a method
    mock_method = mocker.patch.object(test, 'method')
    test.method()
    # Check the behavior
    assert mock_method.called

    # Patch an attribute (field)
    assert 'origin' == test.field
    mocker.patch.object(test, 'field', 'mocked')
    # Check the result
    assert 'mocked' == test.field
monkeypatch is a fixture built into pytest. Sometimes a test needs to call functionality that depends on global configuration, or that itself calls hard-to-test code (network access, for example). monkeypatch provides methods for safely patching and mocking such functionality in tests, and all its changes are undone when the test finishes:
monkeypatch.setattr(obj, name, value, raising=True)
monkeypatch.delattr(obj, name, raising=True)
monkeypatch.setitem(mapping, name, value)
monkeypatch.delitem(obj, name, raising=True)
monkeypatch.setenv(name, value, prepend=None)
monkeypatch.delenv(name, raising=True)
monkeypatch.syspath_prepend(path)
monkeypatch.chdir(path)
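A common use is configuration read from environment variables. get_database_name in this sketch is hypothetical; it mirrors the os.getenv('DATABASE_NAME', 'database.db') pattern used by the blog example later in this article:

```python
import os

import pytest


def get_database_name():
    # Code under test: reads configuration from the environment
    return os.getenv("DATABASE_NAME", "database.db")


def test_database_name_from_env(monkeypatch):
    # setenv is undone automatically when the test finishes
    monkeypatch.setenv("DATABASE_NAME", "test.db")
    assert get_database_name() == "test.db"


def test_database_name_default(monkeypatch):
    # raising=False: no error if the variable was not set in the first place
    monkeypatch.delenv("DATABASE_NAME", raising=False)
    assert get_database_name() == "database.db"
```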
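For test and coverage reports, the following tools are mainly considered:
- coverage (coverage.py) is the most widely used coverage measurement tool for Python.
- pytest-cov is a pytest plugin that lets you use coverage.py from within pytest.
- HtmlTestRunner requires a little configuration in the test code, but produces good-looking reports.

coverage and pytest-cov need only configuration and work without any changes to the test code.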
pip install coverage
For details, see the coverage documentation.
coverage run -m unittest discover
When the run finishes, it creates a coverage data file named .coverage, which PyCharm recognizes as a database file.
1.3.1 report
coverage report -m
The output is as follows:
$ coverage report -m
Name                                                         Stmts   Miss  Cover   Missing
--------------------------------------------------------------------------------------------
src/tests/demo/test_calculator_pytest_with_fixture.py           28     16    43%   8-10, 15-17, 20-22, 26-28, 32-34, 38
src/tests/demo/test_calculator_pytest_with_parameterize.py      15      7    53%   9-11, 19-21, 25
src/tests/demo/test_calculator_unittest.py                      22      1    95%   31
src/tests/demo/test_calculator_unittest_with_ddt.py             13      1    92%   18
1.3.2 html
The following command generates htmlcov/index.html, which can be opened in a browser:
coverage html
Click each .py file to see line-by-line details.
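Instead of repeating options on the command line, coverage.py can read them from a .coveragerc file in the working directory. A minimal sketch, assuming the src/ layout used in this article (the chosen options are illustrative, not required):

```ini
# .coveragerc
[run]
# Measure only code under src/
source = src
# Also record branch (if/else) coverage, not just executed lines
branch = True

[report]
# Same effect as "coverage report -m"
show_missing = True
```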
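pip install html-testRunner
For details, see the HtmlTestRunner documentation.
Add HTMLTestRunner to the code as follows:

import unittest

import HtmlTestRunner

# some tests here

if __name__ == '__main__':
    unittest.main(testRunner=HtmlTestRunner.HTMLTestRunner())

When running a test suite, simply switch the runner to HTMLTestRunner:

from HtmlTestRunner import HTMLTestRunner

# Create the test runner
# runner = unittest.TextTestRunner()
runner = HTMLTestRunner()
runner.run(suit)

By default, reports are written to a reports/ folder and organized by run time.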
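pip install pytest-cov
For details, see the pytest-cov documentation.
pytest --cov --cov-report=html
Or, to measure a specific directory:
pytest --cov=src --cov-report=html
This generates htmlcov/index.html, which can be viewed in a browser and looks similar to the coverage report.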
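If no HtmlTestRunner report appears when running from PyCharm, run the file with a plain Python run configuration rather than a "Python tests" configuration.
In PyCharm, unittest or pytest can be selected as the default test runner.
PyCharm can also show coverage results in a coverage tool window.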
As a fuller example, consider a simple blog system that supports:
- creating an article
- getting an article
- listing articles
├── README.md
├── requirements.txt
└── src
    ├── blog
    │   ├── __init__.py
    │   ├── app.py
    │   ├── commands.py
    │   ├── database.db
    │   ├── init_db.py
    │   ├── models.py
    │   └── queries.py
    └── tests
        └── blog
            ├── __init__.py
            ├── conftest.py
            ├── schemas
            │   ├── Article.json
            │   ├── ArticleList.json
            │   └── __init__.py
            ├── test_app.py
            ├── test_commands.py
            └── test_queries.py
models.py:
import os
import sqlite3
import uuid
from typing import List

from pydantic import BaseModel, EmailStr, Field


class NotFound(Exception):
    pass


class Article(BaseModel):
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    author: EmailStr
    title: str
    content: str

    @classmethod
    def get_by_id(cls, article_id: str):
        con = sqlite3.connect(os.getenv('DATABASE_NAME', 'database.db'))
        con.row_factory = sqlite3.Row
        cur = con.cursor()
        cur.execute("SELECT * FROM articles WHERE id=?", (article_id,))
        record = cur.fetchone()
        # Close before raising so a missing article doesn't leak the connection
        con.close()
        if record is None:
            raise NotFound
        return cls(**record)  # Row can be unpacked as dict

    @classmethod
    def get_by_title(cls, title: str):
        con = sqlite3.connect(os.getenv('DATABASE_NAME', 'database.db'))
        con.row_factory = sqlite3.Row
        cur = con.cursor()
        cur.execute("SELECT * FROM articles WHERE title = ?", (title,))
        record = cur.fetchone()
        con.close()
        if record is None:
            raise NotFound
        return cls(**record)  # Row can be unpacked as dict

    @classmethod
    def list(cls) -> List['Article']:
        con = sqlite3.connect(os.getenv('DATABASE_NAME', 'database.db'))
        con.row_factory = sqlite3.Row
        cur = con.cursor()
        cur.execute("SELECT * FROM articles")
        records = cur.fetchall()
        articles = [cls(**record) for record in records]
        con.close()
        return articles

    def save(self) -> 'Article':
        with sqlite3.connect(os.getenv('DATABASE_NAME', 'database.db')) as con:
            cur = con.cursor()
            cur.execute(
                "INSERT INTO articles (id,author,title,content) VALUES(?, ?, ?, ?)",
                (self.id, self.author, self.title, self.content)
            )
            con.commit()
        return self

    @classmethod
    def create_table(cls, database_name='database.db'):
        conn = sqlite3.connect(database_name)
        conn.execute(
            'CREATE TABLE IF NOT EXISTS articles (id TEXT, author TEXT, title TEXT, content TEXT)'
        )
        conn.close()
commands.py:

from pydantic import BaseModel, EmailStr

from src.blog.models import Article, NotFound


class AlreadyExists(Exception):
    pass


class CreateArticleCommand(BaseModel):
    author: EmailStr
    title: str
    content: str

    def execute(self) -> Article:
        # Refuse to create a second article with the same title
        try:
            Article.get_by_title(self.title)
            raise AlreadyExists
        except NotFound:
            pass

        article = Article(
            author=self.author,
            title=self.title,
            content=self.content
        ).save()

        return article
Unit tests, test_commands.py:

import pytest

from src.blog.commands import CreateArticleCommand, AlreadyExists
from src.blog.models import Article


def test_create_article():
    """
    GIVEN CreateArticleCommand with a valid properties author, title and content
    WHEN the execute method is called
    THEN a new Article must exist in the database with the same attributes
    """
    cmd = CreateArticleCommand(
        author='john@doe.com',
        title='New Article',
        content='Super awesome article'
    )

    article = cmd.execute()

    db_article = Article.get_by_id(article.id)

    assert db_article.id == article.id
    assert db_article.author == article.author
    assert db_article.title == article.title
    assert db_article.content == article.content


def test_create_article_with_mock(monkeypatch):
    """
    GIVEN CreateArticleCommand with valid properties author, title and content
    WHEN the execute method is called
    THEN a new Article must exist in the database with same attributes
    """
    article = Article(
        author='john@doe.com',
        title='New Article',
        content='Super awesome article'
    )
    monkeypatch.setattr(
        Article,
        'save',
        lambda self: article
    )
    cmd = CreateArticleCommand(
        author='john@doe.com',
        title='New Article',
        content='Super awesome article'
    )

    db_article = cmd.execute()

    assert db_article.id == article.id
    assert db_article.author == article.author
    assert db_article.title == article.title
    assert db_article.content == article.content


def test_create_article_already_exists():
    """
    GIVEN CreateArticleCommand with a title of some article in database
    WHEN the execute method is called
    THEN the AlreadyExists exception must be raised
    """
    Article(
        author='jane@doe.com',
        title='New Article',
        content='Super extra awesome article'
    ).save()

    cmd = CreateArticleCommand(
        author='john@doe.com',
        title='New Article',
        content='Super awesome article'
    )

    with pytest.raises(AlreadyExists):
        cmd.execute()
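The tests write to the database, so running them repeatedly needs a clean database each time. That calls for setup and teardown around every test.
conftest.py:

import os
import tempfile

import pytest

from src.blog.models import Article


@pytest.fixture(autouse=True)
def database():
    # Create a temporary database file for each test and point the models at it
    fd, file_name = tempfile.mkstemp()
    os.close(fd)
    os.environ['DATABASE_NAME'] = file_name
    Article.create_table(database_name=file_name)
    yield
    # Teardown: delete the temporary database
    os.unlink(file_name)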
Run the tests again. Output:
$ python3 -m pytest src/tests/blog/test_commands.py
=================== test session starts ======================
platform darwin -- Python 3.8.3, pytest-6.2.2, py-1.10.0, pluggy-0.13.1
rootdir: python-ut
plugins: metadata-1.11.0, html-3.1.1, mock-3.5.1
collected 3 items

src/tests/blog/test_commands.py ... [100%]
===================== 3 passed in 0.02s =======================
queries.py:

from typing import List

from pydantic import BaseModel

from src.blog.models import Article


class ListArticlesQuery(BaseModel):
    def execute(self) -> List[Article]:
        articles = Article.list()
        return articles
Unit tests, test_queries.py:

from src.blog.models import Article
from src.blog.queries import ListArticlesQuery, GetArticleByIDQuery


def test_list_articles():
    """
    GIVEN 2 articles stored in the database
    WHEN the execute method is called
    THEN it should return 2 articles
    """
    Article(
        author='jane@doe.com',
        title='New Article',
        content='Super extra awesome article'
    ).save()
    Article(
        author='jane@doe.com',
        title='Another Article',
        content='Super awesome article'
    ).save()

    query = ListArticlesQuery()

    assert len(query.execute()) == 2
Add to queries.py:

class GetArticleByIDQuery(BaseModel):
    id: str

    def execute(self) -> Article:
        article = Article.get_by_id(self.id)
        return article
Add to test_queries.py:

def test_get_article_by_id():
    """
    GIVEN ID of article stored in the database
    WHEN the execute method is called on GetArticleByIDQuery with id set
    THEN it should return the article with the same id
    """
    article = Article(
        author='jane@doe.com',
        title='New Article',
        content='Super extra awesome article'
    ).save()

    query = GetArticleByIDQuery(
        id=article.id
    )

    assert query.execute().id == article.id
Application entry point, app.py:

from flask import Flask, jsonify, request
from pydantic import ValidationError

from src.blog.commands import CreateArticleCommand
from src.blog.queries import GetArticleByIDQuery, ListArticlesQuery

app = Flask(__name__)


@app.route('/articles/', methods=['POST'])
def create_article():
    cmd = CreateArticleCommand(
        **request.json
    )
    return jsonify(cmd.execute().dict())


@app.route('/articles/<article_id>/', methods=['GET'])
def get_article(article_id):
    query = GetArticleByIDQuery(
        id=article_id
    )
    return jsonify(query.execute().dict())


@app.route('/articles/', methods=['GET'])
def list_articles():
    query = ListArticlesQuery()
    records = [record.dict() for record in query.execute()]
    return jsonify(records)


@app.errorhandler(ValidationError)
def handle_validation_exception(error):
    # Invalid or missing request fields -> 400 with pydantic's error details
    response = jsonify(error.errors())
    response.status_code = 400
    return response


if __name__ == '__main__':
    app.run()
JSON schemas for validating the response payloads:
Article.json
{
  "$schema": "http://json-schema.org/draft-07/schema#",
  "title": "Article",
  "type": "object",
  "properties": {
    "id": {
      "type": "string"
    },
    "author": {
      "type": "string"
    },
    "title": {
      "type": "string"
    },
    "content": {
      "type": "string"
    }
  },
  "required": [
    "id",
    "author",
    "title",
    "content"
  ]
}
ArticleList.json
{
  "$schema": "http://json-schema.org/draft-07/schema#",
  "title": "ArticleList",
  "type": "array",
  "items": {
    "$ref": "file:Article.json"
  }
}
End-to-end tests that go through the application itself, test_app.py:

import json
import pathlib

import pytest
from jsonschema import validate, RefResolver

from src.blog.app import app
from src.blog.models import Article


@pytest.fixture
def client():
    app.config['TESTING'] = True

    with app.test_client() as client:
        yield client


def validate_payload(payload, schema_name):
    """
    Validate payload with selected schema
    """
    schemas_dir = str(
        f'{pathlib.Path(__file__).parent.absolute()}/schemas'
    )
    schema = json.loads(pathlib.Path(f'{schemas_dir}/{schema_name}').read_text())
    validate(
        payload,
        schema,
        resolver=RefResolver(
            'file://' + str(pathlib.Path(f'{schemas_dir}/{schema_name}').absolute()),
            schema  # it's used to resolve file: inside schemas correctly
        )
    )


def test_create_article(client):
    """
    GIVEN request data for new article
    WHEN endpoint /articles/ is called
    THEN it should return Article in json format matching schema
    """
    data = {
        'author': 'john@doe.com',
        'title': 'New Article',
        'content': 'Some extra awesome content'
    }
    response = client.post(
        '/articles/',
        data=json.dumps(
            data
        ),
        content_type='application/json',
    )

    validate_payload(response.json, 'Article.json')


def test_get_article(client):
    """
    GIVEN ID of article stored in the database
    WHEN endpoint /articles/<id-of-article>/ is called
    THEN it should return Article in json format matching schema
    """
    article = Article(
        author='jane@doe.com',
        title='New Article',
        content='Super extra awesome article'
    ).save()
    response = client.get(
        f'/articles/{article.id}/',
        content_type='application/json',
    )

    validate_payload(response.json, 'Article.json')


def test_list_articles(client):
    """
    GIVEN articles stored in the database
    WHEN endpoint /articles/ is called
    THEN it should return list of Article in json format matching schema
    """
    Article(
        author='jane@doe.com',
        title='New Article',
        content='Super extra awesome article'
    ).save()
    response = client.get(
        '/articles/',
        content_type='application/json',
    )

    validate_payload(response.json, 'ArticleList.json')


@pytest.mark.parametrize(
    'data',
    [
        {
            'author': 'John Doe',
            'title': 'New Article',
            'content': 'Some extra awesome content'
        },
        {
            'author': 'John Doe',
            'title': 'New Article',
        },
        {
            'author': 'John Doe',
            'title': None,
            'content': 'Some extra awesome content'
        }
    ]
)
def test_create_article_bad_request(client, data):
    """
    GIVEN request data with invalid values or missing attributes
    WHEN endpoint /articles/ is called
    THEN it should return status 400 and JSON body
    """
    response = client.post(
        '/articles/',
        data=json.dumps(
            data
        ),
        content_type='application/json',
    )

    assert response.status_code == 400
    assert response.json is not None
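With that, the small web application above is essentially complete. Its tests cover basic service-layer unit tests, a temporary test database, mocked article creation and parametrized request validation.
For unit testing, Python ships with unittest, and third-party frameworks build on top of it. The standard library has few plugins but is very easy to extend. The third-party frameworks bundle many plugins and are easy to pick up.
Python is a scripting language. Unlike compiled languages, it is not compiled ahead of time into a binary; an interpreter executes it dynamically at runtime. Its structure is flexible and changeable, but that doesn't stop us from using unit tests to guarantee its quality, weigh its design and set tangible and intangible constraints that safeguard development.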