TDD Approach to Create an Authentication System With FastAPI Part 3

Introduction

If you have been following this series from the start, you know that we can already create a user from our API. But our tests are failing right now, which is kind of smelly. In this post, we tackle the problem at hand and discuss some methods we can use to do it. One of them is dependency injection.

Don’t panic, dependency injection is just a fancy term for something very common. Given that you are already familiar with Object Oriented Programming, you’ll catch up in no time.

To solve the problem at hand, we first need to understand what the problem is…

Why is our unit test failing?

First of all, our test depends on a running database. Thus the test we are having right now is less of a unit test and more of an integration test.

Let’s look at the output of our last test:

$ pytest
================================================== test session starts ==================================================
platform linux -- Python 3.7.10, pytest-6.2.5, py-1.11.0, pluggy-1.0.0
rootdir: /efs/repos/fastauth
plugins: anyio-3.3.4
collected 5 items                                                                                                       

tests/test_main.py .                                                                                              [ 20%]
tests/test_users.py ...F                                                                                          [100%]

======================================================= FAILURES ========================================================
__________________________ TestUserRegistration.test_post_request_with_proper_body_returns_201 __________________________

self = <tests.test_users.TestUserRegistration object at 0x7fa6570ef8d0>

    def test_post_request_with_proper_body_returns_201(self):
        response = client.post(
            "/users/register",
            json={"username": "santosh", "password": "sntsh", "fullname": "Santosh Kumar"}
        )
>       assert response.status_code == 201
E       assert 409 == 201
E        +  where 409 = <Response [409]>.status_code

tests/test_users.py:39: AssertionError
================================================ short test summary info ================================================
FAILED tests/test_users.py::TestUserRegistration::test_post_request_with_proper_body_returns_201 - assert 409 == 201
============================================== 1 failed, 4 passed in 1.19s ==============================================

Wooo… looks like we are getting a 409 response instead of 201. We wrote that code ourselves. Remember these lines?

if db_user:
    raise HTTPException(status_code=409, detail="Username already registered")

This is our culprit. And why are we getting this response? Yeah, because the username is already registered.

This is happening because the user already exists in the production database. Notice the word production I used here? If this application were public facing, we must not run our tests against that database. Even if we are doing integration testing, there should be a staging server to test against.

There are a couple of ways we can test our user API code here. We’ll see both of them in brief:

  1. Mock the API calls
  2. Inject the dependency at run time

Both have their own pros and cons and use cases.

What do we want from our tests?

We can handle the situation at hand with different methods, but which one we choose depends on what we want. At the current state of the application, we could create a temporary database in the Postgres server for the tests and then delete it later on. That is totally achievable using pytest. But that is not what I’m looking for.

What I’m looking for is…

  • To be able to run my tests on a CI/CD server.

Which method we choose depends on what satisfies the above goal. Both of the methods I’ve described can be used to do so. Let’s see both of them, and then later choose which one to go with.

Mock the API calls

This is one of the sneakiest ways we can test the code at hand. Let me go ahead and explain what’s going on.

I’ll take the example of a Calculator class for basic understanding. The concept is easier to explain that way. Later on, we will adapt the same code.

calculator.py

import time

class Calculator:
    def add(self, x, y):
        time.sleep(5)
        return x + y

Our Calculator.add simply returns the sum of x and y. But it does so after waiting for 5 seconds.

Next is our test_calculator.py, which bypasses the above function by calling a mock instead of calling it directly.

test_calculator.py

from unittest import TestCase
from unittest.mock import patch


class TestCalculator(TestCase):
    @patch('calculator.Calculator.add', return_value=9)
    def test_add(self, add):
        expected = add(2, 3)
        self.assertEqual(expected, 9)

You don’t need to go too far to experiment with mocking. In the above test file, we are adding 2 and 3, and we are expecting 9 back. Isn’t that what you were wondering about?

But this test will pass.

$ python -m unittest
.
----------------------------------------------------------------------
Ran 1 test in 0.002s

OK

To explain the why part, I’ll have to go ahead and explain the API of the unittest.mock.patch decorator. Believe it or not, it has been part of Python 3 since 2012. And when I went to see the proposal for how it was added, I found out that the mock library existed even before it was added to the standard library. It started in 2007. I was in 6th class back then.

Okay, so I got a little bit off track there, but here it is:

  1. The first argument to unittest.mock.patch is the target: the class or function we are creating a stub for. It is given as a dotted string such as 'calculator.Calculator.add', and it must be importable from the environment you are calling patch from.
  2. The second argument we have passed is return_value. It tells patch to return a value of 9 whenever we call calculator.Calculator.add.

Under the hood, the @patch decorator replaces the target with a mock.MagicMock object by default.
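The test above calls the mock directly. As a small sketch of what patching does to the real class, the block below defines Calculator inline and uses patch.object, a sibling of patch that takes the object itself instead of a dotted string, so that everything fits in one file. This is a swap made for the sake of a self-contained example, not what the test file above does.

```python
import time
from unittest.mock import MagicMock, patch


class Calculator:
    def add(self, x, y):
        time.sleep(5)  # simulates a slow operation
        return x + y


with patch.object(Calculator, "add", return_value=9) as mock_add:
    # While the patch is active, the attribute on the class is a MagicMock,
    # so the call below returns immediately instead of sleeping.
    result = Calculator().add(2, 3)
    replaced_with_mock = isinstance(Calculator.add, MagicMock)

print(result)  # 9
mock_add.assert_called_once_with(2, 3)  # the mock remembers its arguments

# Outside the with block the original (slow) method is restored.
restored = not isinstance(Calculator.add, MagicMock)
```

The mock also records how it was called, which is often more useful to assert on than the canned return value.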

Using side_effect in @patch

We have seen how we can tell the patch function to return a hardcoded value when we ask it to run a function/class. But this is mostly boring. We want more dynamic behavior. What if we had another function which replicated our add function?

That’s what we’ll do in our next step. Let’s create a function without time.sleep and pass it to patch as the side_effect.

Here is a modified version of the test:

from unittest import TestCase
from unittest.mock import patch


def mock_add(x, y):
    return x + y


class TestCalculator(TestCase):
    @patch('calculator.Calculator.add', side_effect=mock_add)
    def test_add(self, add):
        expected = add(2, 3)
        self.assertEqual(expected, 5)

This is more or less the same code, but we have dynamic behavior now. We can pass our own values to add, and mock_add will be executed every time.
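side_effect accepts more than a function. Passing an exception makes the mock raise it, and passing an iterable returns one item per call. Here is a quick sketch with bare Mock objects (no patching involved); the variable names are only for illustration.

```python
from unittest.mock import Mock

# A callable side_effect receives the call's arguments; its return value is used.
add = Mock(side_effect=lambda x, y: x + y)
total = add(2, 3)

# An exception side_effect is raised when the mock is called.
broken_add = Mock(side_effect=ValueError("calculator is offline"))
try:
    broken_add(2, 3)
    raised = False
except ValueError:
    raised = True

# An iterable side_effect hands out one value per call, in order.
flaky_add = Mock(side_effect=[1, 2])
values = [flaky_add(), flaky_add()]

print(total, raised, values)  # 5 True [1, 2]
```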

Mocking our TestClient.post

Let’s put the knowledge we have gained to use. We have our tests, and one of them is failing.

    def test_post_request_with_proper_body_returns_201(self):
        response = client.post(
            "/users/register",
            json={"username": "santosh", "password": "sntsh", "fullname": "Santosh Kumar"}
        )
        assert response.status_code == 201

Here I have modified the test to call our mock version of TestClient.post. Here’s the diff:

+from unittest.mock import patch
+
 import pytest
 from fastapi.testclient import TestClient
 
@@ -5,6 +7,17 @@ from main import app
 
 client = TestClient(app)
 
+def mock_post(endpoint, json):
+    field_names = json.keys()
+
+    good_mock_response = {'status_code': 201}
+    bad_mock_response = {'status_code': 500}
+
+    if endpoint == "/users/register":
+        if all(field in field_names for field in ['username', 'password', 'fullname']):
+            return good_mock_response
+
+    return bad_mock_response
 
 class TestUserRegistration:
     """TestUserRegistration tests /users/register"""
@@ -27,9 +40,12 @@ class TestUserRegistration:
         )
         assert response.status_code == 422
 
-    def test_post_request_with_proper_body_returns_201(self):
-        response = client.post(
+    @patch('fastapi.testclient.TestClient.post', side_effect=mock_post)
+    def test_post_request_with_proper_body_returns_201(self, post):
+        response = post(
             "/users/register",
             json={"username": "santosh", "password": "sntsh", "fullname": "Santosh Kumar"}
         )
-        assert response.status_code == 201
+        assert response['status_code'] == 201
+

I have written logic in mock_post similar to the original. I am checking for username, password and fullname in the JSON. If all of them are present, I return status code 201. Then I look for this status code in the test.

Now the test passes no matter how many times I run it.

$ pytest --no-header --no-summary
==================== test session starts ====================
collected 5 items                                           

tests/test_main.py .                                  [ 20%]
tests/test_users.py ....                              [100%]

===================== 5 passed in 1.47s =====================

Final words: My final words about mocking in our case would be that it is not quite elegant. At the end of the day we are testing our APIs, and we can’t just mock the calls to the API. If we do so, what is left to test? I demonstrated this method only for the sake of knowledge. Please go ahead and revert all the changes made to test_users.py so far.

If you are more interested in mocking, have a look at this post: https://semaphoreci.com/community/tutorials/mocks-and-monkeypatching-in-python

Next, we’ll see a more professional way to test our code.

Overriding dependencies aka Dependency Injection

I have written an entire post about dependency injection which is yet to be published. In that post I talk about how Object Oriented Programming and Dependency Injection are in the same boat.

To give you a visual cue of what dependency injection looks like, here is my handcrafted diagram.

Visual cue of Dependency Injection

In the above diagram, FakeEmailGateway can be used for test invocations without actually sending real emails.
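In case the diagram does not render, here is the same idea as a short sketch. Only FakeEmailGateway comes from the diagram; SmtpEmailGateway, WelcomeService and the send signature are names I am assuming purely for illustration.

```python
class SmtpEmailGateway:
    """Hypothetical production gateway; a real one would talk to a mail server."""

    def send(self, to, subject):
        raise RuntimeError("would send a real email")


class FakeEmailGateway:
    """Test double that records messages instead of sending them."""

    def __init__(self):
        self.sent = []

    def send(self, to, subject):
        self.sent.append((to, subject))


class WelcomeService:
    """The gateway is injected, so the caller decides which one is used."""

    def __init__(self, gateway):
        self.gateway = gateway

    def welcome(self, email):
        self.gateway.send(email, "Welcome!")


fake = FakeEmailGateway()
WelcomeService(fake).welcome("santosh@example.com")
print(fake.sent)  # [('santosh@example.com', 'Welcome!')]
```

WelcomeService never knows which gateway it got. Production code would pass the real one, tests pass the fake.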

If you know a little bit of object oriented programming, you might know how it works. I’ll not go deep into object oriented programming in this post. Please check out my other post titled OOPs and Mocking Meet at Dependency Injection.

Look closely at our database dependency

Do you remember the signature of our register_user? If not, it looks something like this:

def register_user(user: schemas.UserCreate, db: Session = Depends(get_db)):

What’s so unique in this? Hint: Flask does not come with this.

Yeah, you saw it correctly. The second parameter to our route handler is db. And instead of directly calling get_db, which talks to the production database, we have wrapped it with something called Depends. That’s an elegant way of handling our dependency. That’s the beauty of FastAPI.

Let’s look closely at what purpose Depends serves in FastAPI.

Looking at the definition of get_db, we can see that what this function provides is a db session.

from fastauth.database import SessionLocal

def get_db():
    db = None
    try:
        db = SessionLocal()
        yield db
    finally:
        db.close()
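The yield inside try/finally is what lets the session be handed to the route and closed afterwards. Below is a rough sketch of that lifecycle in plain Python, with a hypothetical FakeSession standing in for SessionLocal. FastAPI drives the generator for us; here we drive it by hand to see the order of events.

```python
class FakeSession:
    """Hypothetical stand-in for a SQLAlchemy session."""

    def __init__(self):
        self.closed = False

    def close(self):
        self.closed = True


def get_db():
    db = FakeSession()
    try:
        yield db  # the route handler runs while the generator is paused here
    finally:
        db.close()  # runs once the generator is closed or exhausted


gen = get_db()
session = next(gen)  # what the handler would receive as `db`
open_during_request = not session.closed
gen.close()  # request finished: the finally block runs
print(open_during_request, session.closed)  # True True
```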

What if we could somehow modify this function to use a temporary SQLite database for testing? With a SQLite database, we won’t need a dedicated database server running, and we can easily create and delete database files as we need. And thinking about the CI/CD server where these tests will run, we can include a stage to clean up dangling database files if any are created.

Overriding our database dependency

Say hello to dependency_overrides. This is a dictionary whose key is a dependency, e.g. the one we used above, and whose value is what we want to override this dependency with. So if your FastAPI instance is named app, the whole thing would look like this:

app.dependency_overrides[get_db] = get_test_db

Here the get_test_db is supposed to return a database session which uses a SQLite bind.
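Conceptually the lookup is tiny. The following is a toy model of the idea and not FastAPI’s actual implementation: before calling a dependency, check whether the dictionary holds a replacement for it. The resolve helper and the string return values are made up for illustration.

```python
def get_db():
    return "postgres session"


def get_test_db():
    return "sqlite session"


def resolve(dependency, overrides):
    """Call the override registered for `dependency`, or the dependency itself."""
    return overrides.get(dependency, dependency)()


dependency_overrides = {}
before = resolve(get_db, dependency_overrides)

# Functions are hashable, so the original dependency can be the dictionary key.
dependency_overrides[get_db] = get_test_db
after = resolve(get_db, dependency_overrides)

print(before, "->", after)  # postgres session -> sqlite session
```

This is also why the test module has to import the original get_db: the function object itself is the key.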

Since we now have two get_db functions with overlapping functionality, I have gone ahead and refactored the code a bit. My database.py looks like this right now.

fastauth/database.py

+from abc import abstractmethod
+
 from sqlalchemy import create_engine
 from sqlalchemy.ext.declarative import declarative_base
 from sqlalchemy.orm import sessionmaker
 
-SQLALCHEMY_DATABASE_URL = "postgresql+psycopg2://postgres:postgres@localhost/fastauth"
+from fastauth import models
 
-engine = create_engine(SQLALCHEMY_DATABASE_URL)
+Base = declarative_base()
 
-SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
 
-Base = declarative_base()
+class BaseDBInit:
+    def __init__(self, db_uri) -> None:
+        self.db_uri = db_uri
+        self.engine = None
+        self.create_engine()
+        models.Base.metadata.create_all(bind=self.engine)
+
+    @abstractmethod
+    def create_engine(self):
+        pass
+
+    def get_session(self):
+        session = sessionmaker(autocommit=False, autoflush=False, bind=self.engine)
+        return session
+
+
+class DBInitTest(BaseDBInit):
+    def create_engine(self):
+        self.engine = create_engine(
+            self.db_uri, connect_args={"check_same_thread": False}
+        )
+
+
+class DBInit(BaseDBInit):
+    def create_engine(self):
+        self.engine = create_engine(self.db_uri)

Instead of creating the engine and session in a static manner, we have added dynamic behavior to them. We have gone object oriented here and are leveraging the inheritance features of Python. Explanation below:

  1. Given that there are two consumers of our session object, we have created a base class BaseDBInit to hold the code that is common to them.
  2. What is common to them? Let’s look at the __init__ dunder method for the answer. We are attaching db_uri to the instance, calling create_engine, and creating all the models in the database. No matter what, this chunk of code will be run by both of the subclasses.
  3. We just said in the above point that all the deriving classes will call the create_engine method. But BaseDBInit.create_engine has no body. Why? Because with abstractmethod, we are indicating that this method needs to be implemented in the deriving classes.
  4. The method with sessionmaker is pretty straightforward. We are doing it the old way, but inside a method of BaseDBInit, using the engine which is attached to the instance (self.engine). Note that get_session returns the session factory, which still has to be called to get a session.
  5. Next we are defining DBInitTest and DBInit, which are derived from our BaseDBInit. In these classes, we define create_engine specifically for testing (SQLite) and for production (Postgres).
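One caveat worth knowing: abstractmethod is only enforced when the class inherits from abc.ABC (or uses ABCMeta as its metaclass). Here is a minimal sketch with simplified, hypothetical class names, where a plain string stands in for the SQLAlchemy engine:

```python
from abc import ABC, abstractmethod


class BaseInit(ABC):
    def __init__(self, db_uri):
        self.db_uri = db_uri
        self.engine = None
        self.create_engine()

    @abstractmethod
    def create_engine(self):
        """Subclasses must set self.engine."""


class SqliteInit(BaseInit):
    def create_engine(self):
        # A string stands in for the real engine object in this sketch.
        self.engine = f"engine for {self.db_uri}"


try:
    BaseInit("sqlite:///./test.db")  # abstract class: cannot be instantiated
    base_rejected = False
except TypeError:
    base_rejected = True

print(base_rejected, SqliteInit("sqlite:///./test.db").engine)
# True engine for sqlite:///./test.db
```

Without the ABC base, the decorator works only as documentation, and instantiating the base class directly would silently leave the engine as None.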

Next, let’s now see how both DBInitTest and DBInit are being used in tests and production.

main.py

We have modified how we connect to the database in the main.py file. Let’s see how:

 from fastapi import FastAPI, Depends, HTTPException
 from sqlalchemy.orm import Session
 
-from fastauth import models, schemas, crud
-from fastauth.database import engine, SessionLocal
+from fastauth import schemas, crud
+from fastauth.database import DBInit
 
-models.Base.metadata.create_all(bind=engine)
 
 def get_db():
-    db = None
+    session = None
     try:
-        db = SessionLocal()
-        yield db
+        session = DBInit("postgresql+psycopg2://postgres:postgres@localhost/fastauth").get_session()
+        session = session()
+        yield session
     finally:
-        db.close()
+        session.close()
 
 app = FastAPI()

And let’s see the user test now.

-from main import app
+from main import app, get_db
+from fastauth.database import DBInitTest
+
+
+def get_test_db():
+    session = None
+    try:
+        session = DBInitTest("sqlite:///./test.db").get_session()
+        session = session()
+        yield session
+    finally:
+        session.close()
+
+
+app.dependency_overrides[get_db] = get_test_db
+

Other than get_test_db, which we are using to get a session from a SQLite database, we also need to import the get_db function to tell our FastAPI instance ‘hey! we are overriding this particular function’. This is done on the last line above.

We might also need to refactor our testing code further as we add more and more tests.

If you want to know more about testing with dependencies in FastAPI, please have a look at these resources:

Why is the test still red after injecting the dependency?

Let’s look at how our tests do.

$ pytest -q
.....                                                                     [100%]
5 passed in 1.80s

$ pytest -q
[...]

1 failed, 4 passed in 1.67s

It passes, then it fails. The same thing is happening again. Are we back to square one? Why did we do so much work if the result is all the same?

I feel your pain. Let me explain why. Believe me, we have made progress and are not stuck at the same stage.

Our test_post_request_with_proper_body_returns_201 is failing, and in my view that is correct behavior. The first run inserted a row into test.db, and the file survives between runs, so the second run finds a user that shouldn’t be there and gets a 409. If we look at the Cleanup section of Anatomy of a test, it says that one test should not affect another test. The solution is to remove this row when this test is run.
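The pass-then-fail pattern is easy to reproduce with nothing but the standard library. This sketch uses sqlite3 directly, with a hypothetical register function and a temporary file in place of our SQLAlchemy setup and test.db:

```python
import os
import sqlite3
import tempfile

db_path = os.path.join(tempfile.mkdtemp(), "test.db")


def register(username):
    """Insert a user; return 201 on success, 409 if the username exists."""
    conn = sqlite3.connect(db_path)
    try:
        conn.execute("CREATE TABLE IF NOT EXISTS users (username TEXT PRIMARY KEY)")
        conn.execute("INSERT INTO users (username) VALUES (?)", (username,))
        conn.commit()
        return 201
    except sqlite3.IntegrityError:
        return 409
    finally:
        conn.close()


first_run = register("santosh")   # fresh file: the row is inserted
second_run = register("santosh")  # same file: the row is still there
os.unlink(db_path)                # cleanup, like deleting test.db
third_run = register("santosh")   # fresh file again

print(first_run, second_run, third_run)  # 201 409 201
```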

We can use something called fixtures (or setUp/tearDown if you are from an xUnit background).
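For readers coming from the xUnit side, this is roughly what the setUp/tearDown flavor looks like with the standard library’s unittest. The class name and the temporary file path are assumptions for the sake of a runnable sketch, not part of our project.

```python
import os
import tempfile
import unittest


class TestWithCleanDatabase(unittest.TestCase):
    def setUp(self):
        # Runs before every test: point at a database file that does not exist yet.
        self.db_path = os.path.join(tempfile.mkdtemp(), "test.db")

    def tearDown(self):
        # Runs after every test, even a failing one: remove what the test created.
        if os.path.exists(self.db_path):
            os.unlink(self.db_path)

    def test_starts_without_database_file(self):
        self.assertFalse(os.path.exists(self.db_path))
        open(self.db_path, "w").close()  # pretend the test wrote some data


suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestWithCleanDatabase)
result = unittest.TextTestRunner(verbosity=0).run(suite)
print(result.wasSuccessful())  # True
```

A pytest fixture plays the same role, except that a test asks for it by name instead of inheriting it from a class.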

Note: You might have one question in your mind right now: why did we create so many classes and do so much refactoring when we could have done this with Postgres itself? The answer is simple. First, we use a file-based database which we can delete at any time. Second, as you might have guessed, it’s easier to deal with files on CI/CD servers than with database servers.

Using fixtures to delete sqlite as needed

We all now know that deleting the SQLite database file before test_post_request_with_proper_body_returns_201 is the way to go. We are not going to do any other setup for the test in this fixture. If you read the first few paragraphs of the Fixtures How-to, you’ll get a basic understanding of it.

Our fixture is simple: it’s a function which will delete the database file. And because it’s a fixture, we’ll decorate it with pytest.fixture.

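@pytest.fixture
def delete_database():
    # Remove the SQLite file if it exists so the test starts from an empty database.
    if os.path.exists("./test.db"):
        os.unlink("./test.db")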

And if you want to use a fixture, you gotta use that fixture name (function name) as an argument to the test function.

-    def test_post_request_with_proper_body_returns_201(self):
+    def test_post_request_with_proper_body_returns_201(self, delete_database):

That’s all you need to get going. Make sure all your imports are there (os and pytest in this case). Let’s see how our tests are doing.

$ pytest --no-header --no-summary
========================= test session starts ==========================
collected 5 items                                                      

tests/test_main.py .                                             [ 20%]
tests/test_users.py ....                                         [100%]
========================== 5 passed in 1.74s ===========================

$ pytest --no-header --no-summary
========================= test session starts ==========================
collected 5 items                                                      

tests/test_main.py .                                             [ 20%]
tests/test_users.py ....                                         [100%]

========================== 5 passed in 1.73s ===========================

$ pytest --no-header --no-summary
========================= test session starts ==========================
collected 5 items                                                      

tests/test_main.py .                                             [ 20%]
tests/test_users.py ....                                         [100%]

========================== 5 passed in 1.73s ===========================

No matter how many times I run my code, it always works as expected.

Conclusion

Which method is the best?

In this part of the series we mostly discussed two approaches to dealing with our failing tests. One is mocking, also known as faking, which replaces a certain chunk of code with a stand-in. The other is overriding the database dependency and cleaning up with fixtures, also known as set up and tear down. For our current scenario, I have chosen to go with the fixtures method for now.

Thinking about the future, I plan to split my test suite into two parts. One is unit tests, which test only the parts of the application that do not deal directly with the database. The other suite will be for integration testing. I’ll use SQLite to test the scenario of registering the same user twice.

If you want to delve more into mocking examples and practices in Python, I would suggest giving https://realpython.com/python-mock-library/ a read.

WRITTEN BY
Santosh Kumar
Santosh is a Software Developer currently working with Method Studios as a Full Stack Developer.