Unit Test

TDD(실전! 스프링부트 상품-주문 API 개발로 알아보는 TDD)

임요환 2023. 3. 21. 23:58

TDD(Test Driven Development)

선 테스트 후 개발 방식의 프로그래밍으로, 자동화된 테스트 코드를 먼저 작성한 뒤 그 테스트를 통과하기 위한 최소한의 코드를 개발하는 방식이다.

  1. 테스트 케이스 작성
  2. 테스트 케이스를 통과하는 코드 작성
  3. 작성한 코드 리팩터링

왜 RestAssured를 사용했을까?

  • TestRestTemplate은 spring-boot-test에 포함되어 있어 Spring이 직접 제공한다
    • 커뮤니티 글을 찾아보면 RestAssured와 거의 동일한 역할을 한다고 나와 있으며, 차이는 RestAssured가 given/when/then 형태의 읽기 쉬운 DSL을 제공한다는 점이다
  • 결국 스프링 시큐리티를 포함해서 테스트를 하려면 MockMvc로 바꿔야 할까?

테스트 진행 방식

1. POJO 방식의 테스트 구현

/**
 * Step 1: a plain-Java (POJO) test for product registration.
 *
 * <p>Every collaborator (request, service, port, adapter, in-memory repository) is declared
 * as a nested type, so the whole flow can be driven without a Spring container. The nested
 * classes are {@code static} because none of them uses the enclosing test instance; this
 * avoids a hidden reference to the test object and makes the next step (moving each type
 * into its own file) a pure cut-and-paste.
 */
public class ProductServiceTest {
    private ProductService productService;
    private ProductPort productPort;
    private ProductRepository productRepository;

    /** Wires a fresh object graph before every test so that tests share no state. */
    @BeforeEach
    void setUp() {
        productRepository = new ProductRepository();
        productPort = new ProductAdapter(productRepository);
        productService = new ProductService(productPort);
    }

    /**
     * Registering a valid product must complete without throwing.
     *
     * <p>There is no read path yet, so this only proves that validation and saving succeed.
     */
    @Test
    void 상품등록(){
        final String name = "상품명";
        final int price = 1000;
        final DiscountPolicy discountPolicy = DiscountPolicy.NONE;
        final AddProductRequest request = new AddProductRequest(name, price, discountPolicy);

        productService.addProduct(request);
    }

    /** Application service: turns a registration request into a {@link Product} and stores it. */
    private static class ProductService{
        private final ProductPort productPort;
        public ProductService(ProductPort productPort){
            this.productPort = productPort;
        }

        /**
         * Creates a product from the request and hands it to the port.
         *
         * @param request validated registration data
         */
        public void addProduct(final AddProductRequest request){
            final Product product = new Product(request.getName(), request.getPrice(), request.getDiscountPolicy());
            productPort.save(product);
        }
    }

    /** Incoming registration data; rejects a blank name, a non-positive price or a null policy. */
    private static class AddProductRequest {
        private final String name;
        private final int price;
        private final DiscountPolicy discountPolicy;

        /**
         * @throws IllegalArgumentException if {@code name} has no text, {@code price <= 0},
         *     or {@code discountPolicy} is null (thrown by Spring's {@code Assert})
         */
        public AddProductRequest(final String name, final int price, final DiscountPolicy discountPolicy){
            Assert.hasText(name, "상품명은 필수입니다.");
            Assert.isTrue(price > 0, "상품 가격은 0보다 커야합니다.");
            Assert.notNull(discountPolicy, "할인 정책은 필수 입니다.");
            this.name = name;
            this.price = price;
            this.discountPolicy = discountPolicy;
        }

        public String getName() {
            return name;
        }

        public int getPrice() {
            return price;
        }

        public DiscountPolicy getDiscountPolicy() {
            return discountPolicy;
        }
    }

    /** Discount policies; only "no discount" exists at this stage. */
    private enum DiscountPolicy {
        NONE
    }

    /**
     * Product domain object.
     *
     * <p>It re-validates its own invariants rather than trusting the request. The id stays null
     * until the repository assigns one.
     */
    private static class Product {
        private Long id;
        private String name;
        private int price;
        private DiscountPolicy discountPolicy;

        /** @throws IllegalArgumentException on a blank name, non-positive price or null policy */
        public Product(final String name, final int price, final DiscountPolicy discountPolicy){
            Assert.hasText(name, "상품명은 필수입니다.");
            Assert.isTrue(price > 0, "상품 가격은 0보다 커야합니다.");
            Assert.notNull(discountPolicy, "할인 정책은 필수 입니다.");
            this.name = name;
            this.price = price;
            this.discountPolicy = discountPolicy;
        }

        public String getName() {
            return name;
        }

        public int getPrice() {
            return price;
        }

        public DiscountPolicy getDiscountPolicy() {
            return discountPolicy;
        }

        public Long getId() {
            return id;
        }

        /** Called by the repository when the product is first saved. */
        public void assignId(final Long id) {
            this.id = id;
        }
    }

    /** Outgoing port that the service depends on, so storage can be swapped (memory now, JPA later). */
    private interface ProductPort {
        void save(final Product product);
    }

    /** Adapter implementing the port by delegating to the repository. */
    private static class ProductAdapter implements ProductPort{
        private final ProductRepository productRepository;

        private ProductAdapter(ProductRepository productRepository) {
            this.productRepository = productRepository;
        }

        @Override
        public void save(Product product) {
            productRepository.save(product);
        }
    }

    /** In-memory repository that hands out sequential ids starting at 1. */
    private static class ProductRepository {
        private final Map<Long, Product> persistence = new HashMap<>();
        // Primitive counter: avoids unboxing and reboxing a Long on every increment.
        private long sequence = 0L;

        /** Assigns the next id to {@code product} and stores it under that id. */
        public void save(final Product product){
            product.assignId(++sequence);
            persistence.put(product.getId(), product);
        }
    }
}

2. 테스트가 성공하는 것을 확인한 후, inner class들을 각각 독립된 클래스 파일로 이동시킨다.

/**
 * Step 2: the same registration test after every nested type has been moved into its own
 * source file. Only the object wiring and the test itself remain here.
 */
public class ProductServiceTest {
    private ProductRepository productRepository;
    private ProductPort productPort;
    private ProductService productService;

    /** Builds the service on top of a brand-new in-memory repository for each test. */
    @BeforeEach
    void setUp() {
        productRepository = new ProductRepository();
        productPort = new ProductAdapter(productRepository);
        productService = new ProductService(productPort);
    }

    /** A valid registration request must be accepted by the service without throwing. */
    @Test
    void 상품등록(){
        productService.addProduct(상품등록요청_생성());
    }

    /** Creates a valid request: name "상품명", price 1000, no discount policy. */
    private static AddProductRequest 상품등록요청_생성() {
        return new AddProductRequest("상품명", 1000, DiscountPolicy.NONE);
    }
}

3. @SpringBootTest로 전환(각각의 클래스를 스프링 빈으로 등록하는 작업은 생략하였습니다)

/**
 * Step 3: the registration test running inside a Spring context, so the
 * {@code ProductService} bean and its collaborators are injected instead of hand-wired.
 */
@SpringBootTest
class ProductServiceTest {

    @Autowired
    private ProductService productService;

    /** Registering a valid product through the service bean must not throw. */
    @Test
    void 상품등록(){
        productService.addProduct(상품등록요청_생성());
    }

    /** Builds a valid request: name "상품명", price 1000, no discount policy. */
    private static AddProductRequest 상품등록요청_생성() {
        return new AddProductRequest("상품명", 1000, DiscountPolicy.NONE);
    }

}

4. API Test로 변경

  • build.gradle에 RestAssured 의존성 추가
testImplementation 'io.rest-assured:rest-assured:4.4.0'
  • RestAssured용 공통 클래스 생성(API 테스트 클래스에서 상속받아 사용)
/**
 * Base class for API tests.
 *
 * <p>Starts the full application on a random port. Subclasses inherit this configuration
 * and send real HTTP requests through RestAssured.
 */
@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT)
public class ApiTest {

    /** Port the embedded server actually bound to. */
    @LocalServerPort
    private int port;

    /** Points RestAssured at the running server before every test. */
    @BeforeEach
    void setUp() {
        RestAssured.port = this.port;
    }
}
  • 적용(ProductService를 컨트롤러로 만드는 부분은 생략)
/**
 * API-level registration test: sends a real {@code POST /products} request to the server
 * started by {@code ApiTest}.
 */
class ProductApiTest extends ApiTest {

    /** A valid registration request must be answered with 201 Created. */
    @Test
    void 상품등록(){
        final AddProductRequest request = 상품등록요청_생성();

        final ExtractableResponse<Response> response = 상품등록요청(request);

        assertThat(response.statusCode()).isEqualTo(HttpStatus.CREATED.value());
    }

    /** Posts {@code request} as JSON to {@code /products}, logging both request and response. */
    private static ExtractableResponse<Response> 상품등록요청(final AddProductRequest request) {
        return RestAssured.given().log().all()
                .contentType(MediaType.APPLICATION_JSON_VALUE)
                .body(request)
                .when()
                .post("/products")
                .then()
                .log().all()
                .extract();
    }

    /** Builds a valid request: name "상품명", price 1000, no discount policy. */
    private static AddProductRequest 상품등록요청_생성() {
        return new AddProductRequest("상품명", 1000, DiscountPolicy.NONE);
    }
}

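5. 메모리 Repository를 JPA Repository로 변경(엔티티/리포지토리 클래스 수정 부분은 생략). RANDOM_PORT로 띄운 API 테스트는 서버가 별도 스레드에서 요청을 처리하므로 테스트의 @Transactional 롤백이 적용되지 않고, 캐시된 스프링 컨텍스트(및 DB)가 테스트 간에 재사용된다. 그래서 테스트마다 데이터를 비워 주는 데이터베이스 초기화(DatabaseCleanup) 클래스를 추가한다.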

  • DatabaseCleanup 클래스 추가
import com.google.common.base.CaseFormat;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.stereotype.Component;
import org.springframework.transaction.annotation.Transactional;

import javax.persistence.Entity;
import javax.persistence.EntityManager;
import javax.persistence.PersistenceContext;
import javax.persistence.Table;
import javax.persistence.metamodel.EntityType;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Empties every JPA-mapped table so that each API test starts from a clean database.
 *
 * <p>NOTE(review): the native statements ({@code SET REFERENTIAL_INTEGRITY},
 * {@code ALTER COLUMN ID RESTART WITH 1}) are H2 syntax. The id reset also assumes every
 * mapped table has an {@code ID} column, which would fail for, e.g., a join table without one.
 * Confirm both assumptions before using this with another database or schema.
 */
@Component
public class DatabaseCleanup implements InitializingBean {
    @PersistenceContext
    private EntityManager entityManager;

    private List<String> tableNames;

    /**
     * Resolves the table name of every {@code @Entity} known to the metamodel.
     *
     * <p>A non-blank {@code @Table(name = ...)} is used as-is. Otherwise the entity name is
     * converted from UpperCamel to lower_underscore (e.g. {@code OrderItem} -> {@code order_item}).
     * The result is an unmodifiable list. The original code appended to the output of
     * {@code Collectors.toList()}, whose mutability is not guaranteed.
     */
    @Override
    public void afterPropertiesSet() {
        final Set<EntityType<?>> entities = entityManager.getMetamodel().getEntities();
        tableNames = entities.stream()
                .filter(this::isEntity)
                .map(this::resolveTableName)
                .collect(Collectors.toUnmodifiableList());
    }

    /** Returns the explicit {@code @Table} name when present and non-blank, else the snake_case entity name. */
    private String resolveTableName(final EntityType<?> e) {
        final String derivedName = CaseFormat.UPPER_CAMEL.to(CaseFormat.LOWER_UNDERSCORE, e.getName());
        if (!hasTableAnnotation(e)) {
            return derivedName;
        }
        final String explicitName = e.getJavaType().getAnnotation(Table.class).name();
        return explicitName.isBlank() ? derivedName : explicitName;
    }

    private boolean isEntity(final EntityType<?> e) {
        return null != e.getJavaType().getAnnotation(Entity.class);
    }

    private boolean hasTableAnnotation(final EntityType<?> e) {
        return null != e.getJavaType().getAnnotation(Table.class);
    }

    /**
     * Flushes pending changes, then truncates every resolved table and restarts its ID sequence at 1.
     *
     * <p>Foreign-key checks are switched off so tables can be truncated in any order. They are
     * switched back on in a {@code finally} block, so a failing statement cannot leave the database
     * running without referential integrity. H2 applies this setting database-wide.
     */
    @Transactional
    public void execute() {
        entityManager.flush();
        setReferentialIntegrity(false);
        try {
            for (final String tableName : tableNames) {
                entityManager.createNativeQuery("TRUNCATE TABLE " + tableName).executeUpdate();
                entityManager.createNativeQuery("ALTER TABLE " + tableName + " ALTER COLUMN ID RESTART WITH 1").executeUpdate();
            }
        } finally {
            setReferentialIntegrity(true);
        }
    }

    /** Issues H2's {@code SET REFERENTIAL_INTEGRITY TRUE|FALSE}. */
    private void setReferentialIntegrity(final boolean enabled) {
        entityManager.createNativeQuery("SET REFERENTIAL_INTEGRITY " + (enabled ? "TRUE" : "FALSE")).executeUpdate();
    }
}
  • DatabaseCleanup에서 사용하는 Guava(CaseFormat) 의존성 추가(DatabaseCleanup이 테스트 소스에만 있다면 testImplementation으로도 충분하다)
implementation 'com.google.guava:guava:31.1-jre'
  • ApiTest 클래스 수정
/**
 * Base class for API tests.
 *
 * <p>Starts the application on a random port, points RestAssured at it and empties the
 * database before every test.
 */
@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT)
public class ApiTest {

    @Autowired
    private DatabaseCleanup databaseCleanup;

    /** Port the embedded server of the current application context is listening on. */
    @LocalServerPort
    private int port;

    @BeforeEach
    void setUp(){
        // Always assign: RestAssured.port is a JVM-wide static. A test class that needs a different
        // application context gets its own server on a different random port, and assigning only
        // while the port is UNDEFINED_PORT would keep sending requests to the first server.
        RestAssured.port = port;
        // No manual afterPropertiesSet(): DatabaseCleanup is an InitializingBean, so Spring already
        // resolved its table names when the bean was created.
        databaseCleanup.execute();
    }
}

 

https://github.com/limyohwan/tdd-introduction