bong-u/til

Fastapi - 통합테스트 In-Memory DB에서 테이블이 없다는 문제

수정일 : 2024-07-22

상황

  • 테커 부트캠프에서 팀프로젝트를 진행 중이다.
  • 단위테스트 코드는 작성이 완료되었고, 통합테스트 코드를 작성 중이다.
  • sqlite in-memory db를 사용해서 테스트 중인데, 테이블이 없다는 에러가 발생했다.
  • 테스트 전에 테이블을 생성하는 코드가 실행됨에도 불구하고, 에러가 발생한다.
  • 인메모리가 아닌 파일로 저장하는 방법을 사용하면 에러가 발생하지 않는 것을 보고 문제의 원인을 파악할 수 있었다.

코드

 1from database import Base, engine
 2from fastapi.testclient import TestClient
 3
 4from main import app
 5from models import *
 6
 7# 테이블을 생성하는 코드이다
 8Base.metadata.create_all(bind=engine)
 9
10client = TestClient(app)
11
12
13class TestUserApi:
14
15    def test_create_user(self):
16        test_nickname = "test_nickname"
17        # 아래 요청을 처리하는 코드에서 오류가 발생한다
18        response = client.post(
19            "/api/users",
20            json={"nickname": test_nickname},
21        )
22        assert response.status_code == 200
23        assert response.json()["nickname"] == test_nickname
24

원인

  • 테이블을 생성할때 만들어지는 세션과 TestClient가 요청을 처리할 때 사용하는 세션이 다르다.

해결 방법

  • TestClient내에 get_db() 함수를 임의로 주입한다
  • 데이터베이스를 연결할때, 단일 세션을 사용하도록 한다.
 1from database import Base, engine, get_db
 2from sqlalchemy.orm import sessionmaker
 3from fastapi.testclient import TestClient
 4
 5from main import app
 6from models import *
 7
 8Base.metadata.create_all(bind=engine)
 9
10client = TestClient(app)
11
12# 테스트에서 사용할 세션을 생성한다
13TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
14
15
16Base.metadata.create_all(bind=engine)
17
18# get_db() 함수를 재정의한다
19def override_get_db():
20    try:
21        db = TestingSessionLocal()
22        yield db
23    finally:
24        db.close()
25
26# get_db() 함수를 재정의한 함수를 주입한다
27app.dependency_overrides[get_db] = override_get_db
28
29
30class TestUserApi:
31
32    def test_create_user(self):
33        test_nickname = "test_nickname"
34        response = client.post(
35            "/api/users",
36            json={"nickname": test_nickname},
37        )
38        assert response.status_code == 201
39        assert response.json()["nickname"] == test_nickname
1engine = create_engine(    
2    os.getenv("DATABASE_URL"),
3    # sqlite를 사용할 때, 여러 스레드에서 연결이 가능하도록 설정한다
4    connect_args={"check_same_thread": False},
5    # 단일 세션을 사용하도록 설정한다
6    poolclass=StaticPool,
7)