Mocking Spring Data repositories

Unit testing should be a substantial part of every developer's work. If you are using Spring Data, perhaps together with Spring Boot, booting up a test environment can take quite a while. For simple component tests, people often skip setting up the complete environment and use mocks for classes that are unavailable in the “bare” test runtime. In today’s post I'd like to share a little code snippet that mocks a Spring Data CRUD repository while still providing basic database functionality. I use a ConcurrentHashMap to keep the data in memory. “How do you stub the Spring Data methods?” you ask. The answer is Mockito Answers.

package pls.gooby;

import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import org.springframework.beans.BeanUtils;
import org.springframework.data.repository.CrudRepository;

import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;

import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyString;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;

/**
 * Mocking Spring Data CRUD repositories for testing.
 *
 * <p>{@link #mockCrudRepo} returns a Mockito mock whose CRUD methods are answered from an
 * in-memory {@link ConcurrentHashMap}, so simple component tests get working persistence
 * without booting a Spring context.
 */
public class MockSpringData {

    /**
     * Creates a mock of {@code repositoryClass} backed by an in-memory store.
     *
     * <p>Stubbed methods: {@code save(E)}, {@code findOne(String)}, {@code exists(String)},
     * {@code findAll()}, {@code count()}, {@code delete(String)}, {@code delete(E)},
     * {@code delete(Iterable)} and {@code deleteAll()}. Every other method keeps Mockito's
     * default answer.
     *
     * <p>Assumptions: {@code E} is an entity with a {@code String} primary key exposed through
     * public {@code getId()} and {@code setId(String)} methods. An entity saved with a
     * {@code null} id is assigned a random UUID; an entity saved with a non-null id is stored
     * under that id, or merged into the already stored instance if the id is known.
     *
     * @param entityClass     the entity type managed by the repository
     * @param repositoryClass the repository interface to mock
     * @param <E>             the entity type
     * @param <R>             the repository type
     * @return a mock of {@code repositoryClass} whose CRUD methods read and write an in-memory map
     * @throws IllegalArgumentException if {@code entityClass} has no public {@code getId()} or
     *                                  {@code setId(String)} method
     */
    public static <E, R extends CrudRepository<E, String>> R mockCrudRepo(final Class<E> entityClass, Class<R> repositoryClass) {
        // Resolve the id accessors once, so a mis-shaped entity class fails here rather than inside a test.
        final java.lang.reflect.Method getId = accessor(entityClass, "getId");
        final java.lang.reflect.Method setId = accessor(entityClass, "setId", String.class);
        final ConcurrentHashMap<String, E> persistenceMap = new ConcurrentHashMap<>();
        R mockedRepository = mock(repositoryClass);

        doAnswer(new Answer<Object>() {
            @Override
            public Object answer(InvocationOnMock invocation) throws Throwable {
                E entity = entityClass.cast(invocation.getArguments()[0]);
                String id = (String) getId.invoke(entity);
                if (id == null) {
                    id = UUID.randomUUID().toString();
                    setId.invoke(entity, id);
                }
                // Unknown id: store the caller's instance itself. Known id: merge onto the stored
                // instance, so references handed out earlier by save/findOne observe the update.
                E persistedItem = persistenceMap.putIfAbsent(id, entity);
                if (persistedItem == null) {
                    return entity;
                }
                BeanUtils.copyProperties(entity, persistedItem);
                return persistedItem;
            }
        }).when(mockedRepository).save(any(entityClass));

        doAnswer(new Answer<Object>() {
            @Override
            public Object answer(InvocationOnMock invocation) throws Throwable {
                return persistenceMap.get(invocation.getArguments()[0]);
            }
        }).when(mockedRepository).findOne(anyString());

        doAnswer(new Answer<Object>() {
            @Override
            public Object answer(InvocationOnMock invocation) throws Throwable {
                return persistenceMap.containsKey(invocation.getArguments()[0]);
            }
        }).when(mockedRepository).exists(anyString());

        doAnswer(new Answer<Object>() {
            @Override
            public Object answer(InvocationOnMock invocation) throws Throwable {
                // Detached snapshot, like a real query result. An ArrayList also satisfies
                // repository sub-interfaces that narrow findAll() to return a List.
                return new java.util.ArrayList<E>(persistenceMap.values());
            }
        }).when(mockedRepository).findAll();

        doAnswer(new Answer<Object>() {
            @Override
            public Object answer(InvocationOnMock invocation) throws Throwable {
                return (long) persistenceMap.size();
            }
        }).when(mockedRepository).count();

        doAnswer(new Answer<Object>() {
            @Override
            public Object answer(InvocationOnMock invocation) throws Throwable {
                return persistenceMap.remove(invocation.getArguments()[0]);
            }
        }).when(mockedRepository).delete(anyString());

        doAnswer(new Answer<Object>() {
            @Override
            public Object answer(InvocationOnMock invocation) throws Throwable {
                return removeEntity(persistenceMap, getId, invocation.getArguments()[0]);
            }
        }).when(mockedRepository).delete(any(entityClass));

        doAnswer(new Answer<Object>() {
            @Override
            public Object answer(InvocationOnMock invocation) throws Throwable {
                for (Object item : (Iterable<?>) invocation.getArguments()[0]) {
                    removeEntity(persistenceMap, getId, item);
                }
                return null;
            }
        }).when(mockedRepository).delete(any(Iterable.class));

        doAnswer(new Answer<Object>() {
            @Override
            public Object answer(InvocationOnMock invocation) throws Throwable {
                persistenceMap.clear();
                return null;
            }
        }).when(mockedRepository).deleteAll();

        return mockedRepository;
    }

    /**
     * Looks up a public method of {@code type}. A missing method is rethrown as an unchecked
     * exception, so {@link #mockCrudRepo} keeps a signature free of checked exceptions.
     */
    private static java.lang.reflect.Method accessor(Class<?> type, String name, Class<?>... parameterTypes) {
        try {
            return type.getMethod(name, parameterTypes);
        } catch (NoSuchMethodException e) {
            throw new IllegalArgumentException(
                    type.getName() + " must declare a public " + name + " method to be used with MockSpringData", e);
        }
    }

    /**
     * Removes {@code entity} from {@code store} by its id and returns the removed value, if any.
     * An entity whose id is {@code null} was never saved, so nothing is removed. This avoids
     * passing a null key to {@link ConcurrentHashMap#remove}, which would throw.
     */
    private static <E> E removeEntity(ConcurrentHashMap<String, E> store, java.lang.reflect.Method getId, Object entity)
            throws ReflectiveOperationException {
        Object id = getId.invoke(entity);
        return id == null ? null : store.remove(id);
    }
}

A mock can then be created with:

MyCrudRepo repo = pls.gooby.MockSpringData.mockCrudRepo(MyEntity.class, MyCrudRepo.class);

Final note: the mocking code can be simplified further with lambda expressions, provided you are allowed to use Java 8:

package pls.gooby;

import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import org.springframework.beans.BeanUtils;
import org.springframework.data.repository.CrudRepository;

import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;

import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyString;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;

/**
 * Mocking Spring Data CRUD repositories for testing.
 *
 * <p>{@link #mockCrudRepo} returns a Mockito mock whose CRUD methods are answered from an
 * in-memory {@link ConcurrentHashMap}, so simple component tests get working persistence
 * without booting a Spring context.
 */
public class MockSpringData {

    /**
     * Creates a mock of {@code repositoryClass} backed by an in-memory store.
     *
     * <p>Stubbed methods: {@code save(E)}, {@code findOne(String)}, {@code exists(String)},
     * {@code findAll()}, {@code count()}, {@code delete(String)}, {@code delete(E)},
     * {@code delete(Iterable)} and {@code deleteAll()}. Every other method keeps Mockito's
     * default answer.
     *
     * <p>Assumptions: {@code E} is an entity with a {@code String} primary key exposed through
     * public {@code getId()} and {@code setId(String)} methods. An entity saved with a
     * {@code null} id is assigned a random UUID; an entity saved with a non-null id is stored
     * under that id, or merged into the already stored instance if the id is known.
     *
     * @param entityClass     the entity type managed by the repository
     * @param repositoryClass the repository interface to mock
     * @param <E>             the entity type
     * @param <R>             the repository type
     * @return a mock of {@code repositoryClass} whose CRUD methods read and write an in-memory map
     * @throws IllegalArgumentException if {@code entityClass} has no public {@code getId()} or
     *                                  {@code setId(String)} method
     */
    public static <E, R extends CrudRepository<E, String>> R mockCrudRepo(final Class<E> entityClass, Class<R> repositoryClass) {
        // Resolve the id accessors once, so a mis-shaped entity class fails here rather than inside a test.
        final java.lang.reflect.Method getId = accessor(entityClass, "getId");
        final java.lang.reflect.Method setId = accessor(entityClass, "setId", String.class);
        final ConcurrentHashMap<String, E> persistenceMap = new ConcurrentHashMap<>();
        R mockedRepository = mock(repositoryClass);

        doAnswer(invocation -> {
            E entity = entityClass.cast(invocation.getArguments()[0]);
            String id = (String) getId.invoke(entity);
            if (id == null) {
                id = UUID.randomUUID().toString();
                setId.invoke(entity, id);
            }
            // Unknown id: store the caller's instance itself. Known id: merge onto the stored
            // instance, so references handed out earlier by save/findOne observe the update.
            E persistedItem = persistenceMap.putIfAbsent(id, entity);
            if (persistedItem == null) {
                return entity;
            }
            BeanUtils.copyProperties(entity, persistedItem);
            return persistedItem;
        }).when(mockedRepository).save(any(entityClass));

        doAnswer(invocation -> persistenceMap.get(invocation.getArguments()[0])).when(mockedRepository).findOne(anyString());

        doAnswer(invocation -> persistenceMap.containsKey(invocation.getArguments()[0])).when(mockedRepository).exists(anyString());

        // Detached snapshot, like a real query result. An ArrayList also satisfies repository
        // sub-interfaces that narrow findAll() to return a List.
        doAnswer(invocation -> new java.util.ArrayList<E>(persistenceMap.values())).when(mockedRepository).findAll();

        doAnswer(invocation -> (long) persistenceMap.size()).when(mockedRepository).count();

        doAnswer(invocation -> persistenceMap.remove(invocation.getArguments()[0])).when(mockedRepository).delete(anyString());

        doAnswer(invocation -> removeEntity(persistenceMap, getId, invocation.getArguments()[0])).when(mockedRepository).delete(any(entityClass));

        doAnswer(invocation -> {
            for (Object item : (Iterable<?>) invocation.getArguments()[0]) {
                removeEntity(persistenceMap, getId, item);
            }
            return null;
        }).when(mockedRepository).delete(any(Iterable.class));

        doAnswer(invocation -> {
            persistenceMap.clear();
            return null;
        }).when(mockedRepository).deleteAll();

        return mockedRepository;
    }

    /**
     * Looks up a public method of {@code type}. A missing method is rethrown as an unchecked
     * exception, so {@link #mockCrudRepo} keeps a signature free of checked exceptions.
     */
    private static java.lang.reflect.Method accessor(Class<?> type, String name, Class<?>... parameterTypes) {
        try {
            return type.getMethod(name, parameterTypes);
        } catch (NoSuchMethodException e) {
            throw new IllegalArgumentException(
                    type.getName() + " must declare a public " + name + " method to be used with MockSpringData", e);
        }
    }

    /**
     * Removes {@code entity} from {@code store} by its id and returns the removed value, if any.
     * An entity whose id is {@code null} was never saved, so nothing is removed. This avoids
     * passing a null key to {@link ConcurrentHashMap#remove}, which would throw.
     */
    private static <E> E removeEntity(ConcurrentHashMap<String, E> store, java.lang.reflect.Method getId, Object entity)
            throws ReflectiveOperationException {
        Object id = getId.invoke(entity);
        return id == null ? null : store.remove(id);
    }
}
Advertisements

About goobypl5

pizza baker, autodidact, particle physicist
This entry was posted in Programming. Bookmark the permalink.

One Response to Mocking Spring Data repositories

  1. André says:

    Thank you for sharing!

Share your thoughts

Fill in your details below or click an icon to log in:

WordPress.com Logo

You are commenting using your WordPress.com account. ( Log Out / Change )

Twitter picture

You are commenting using your Twitter account. ( Log Out / Change )

Facebook photo

You are commenting using your Facebook account. ( Log Out / Change )

Google+ photo

You are commenting using your Google+ account. ( Log Out / Change )

Connecting to %s