Integration Tests for N + 1 problem in Java by Semyon Kirekov

Semyon Kirekov
May 21, 2023, 11:30 AM

The N + 1 problem is a common issue in many enterprise projects. The worst part is that you don't notice it until the amount of data becomes huge. Unfortunately, by then the code might have reached a stage where dealing with the N + 1 problem becomes an unbearable task.

In this article, I'm going to tell you:

  1. How to track the N + 1 problem automatically.
  2. How to write a test that checks that the query count matches the expected value.

The tech stack consists of Java, Spring Boot, Spring Data JPA, and PostgreSQL. You can check out the repository with code examples at this link.

The solution isn't tied to Spring Boot or Hibernate, though. If your codebase interacts with javax.sql.DataSource, the approach will help you, even if you don't use Spring at all.


An example of the N + 1 problem

Suppose we are working on an application that manages zoos. It has two core entities: Zoo and Animal. Look at the code snippet below:

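// jakarta.persistence is used by Spring Boot 3 / Hibernate 6 (older versions use javax.persistence).
// Each entity lives in its own file; the imports are shown once for brevity.
import jakarta.persistence.Entity;
import jakarta.persistence.GeneratedValue;
import jakarta.persistence.Id;
import jakarta.persistence.JoinColumn;
import jakarta.persistence.ManyToOne;
import jakarta.persistence.OneToMany;
import jakarta.persistence.Table;
import java.util.ArrayList;
import java.util.List;

import static jakarta.persistence.CascadeType.PERSIST;
import static jakarta.persistence.FetchType.LAZY;
import static jakarta.persistence.GenerationType.IDENTITY;

@Entity
@Table(name = "zoo")
public class Zoo {
    @Id
    @GeneratedValue(strategy = IDENTITY)
    private Long id;

    private String name;

    // @OneToMany is LAZY by default: the animals are loaded on first access
    @OneToMany(mappedBy = "zoo", cascade = PERSIST)
    private List<Animal> animals = new ArrayList<>();
}

@Entity
@Table(name = "animal")
public class Animal {
    @Id
    @GeneratedValue(strategy = IDENTITY)
    private Long id;

    @ManyToOne(fetch = LAZY)
    @JoinColumn(name = "zoo_id")
    private Zoo zoo;

    private String name;
}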

Now we want to retrieve all existing zoos with their animals. Look at the code of ZooService below.

@Service
@RequiredArgsConstructor
public class ZooService {
    private final ZooRepository zooRepository;

    @Transactional(readOnly = true)
    public List<ZooResponse> findAllZoos() {
        final var zoos = zooRepository.findAll();
        return zoos.stream()
                   .map(ZooResponse::new)
                   .toList();
    }
}

We also want to check that everything works smoothly, so here is a simple integration test:

@DataJpaTest
@AutoConfigureTestDatabase(replace = NONE)
@Transactional(propagation = NOT_SUPPORTED)
@Testcontainers
@Import(ZooService.class)
class ZooServiceTest {
    @Container
    static final PostgreSQLContainer<?> POSTGRES = new PostgreSQLContainer<>("postgres:13");

    @DynamicPropertySource
    static void setProperties(DynamicPropertyRegistry registry) {
        registry.add("spring.datasource.url", POSTGRES::getJdbcUrl);
        registry.add("spring.datasource.username", POSTGRES::getUsername);
        registry.add("spring.datasource.password", POSTGRES::getPassword);
    }

    @Autowired
    private ZooService zooService;
    @Autowired
    private ZooRepository zooRepository;

    @Test
    void shouldReturnAllZoos() {
        /* data initialization... */
        zooRepository.saveAll(List.of(zoo1, zoo2));

        final var allZoos = zooService.findAllZoos();

        /* assertions... */
        assertThat(
            ...
        );
    }
}

I skipped the data initialization and assertion parts for the sake of simplicity, since they are not important for the article's topic. Anyway, you can check out the whole test suite at this link.

I have a separate article about testing the data layer in a Spring Boot application with Testcontainers. If you're unfamiliar with the topic, you should definitely look through it.

The test passes successfully. However, if you log SQL statements, you'll notice something that may concern you. Look at the output below:

-- selecting all zoos
select z1_0.id,z1_0.name from zoo z1_0
-- selecting animals for the first zoo
select a1_0.zoo_id,a1_0.id,a1_0.name from animal a1_0 where a1_0.zoo_id=?
-- selecting animals for the second zoo
select a1_0.zoo_id,a1_0.id,a1_0.name from animal a1_0 where a1_0.zoo_id=?

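As you can see, there is a separate select query for each Zoo. Since @OneToMany associations are lazy by default, Hibernate loads the animals of each zoo with an extra query when the collection is first accessed (here, while mapping the zoos to ZooResponse). The total number of queries equals the number of selected zoos + 1, which is exactly the N + 1 problem.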

This may cause serious performance penalties, especially with large amounts of data.
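To make the arithmetic concrete, here is a toy simulation in plain Java, with no JPA involved. FakeDb, lazyLoad, and joinFetch are hypothetical names made up for this illustration. Every FakeDb method call stands for one SQL statement, so loading N zoos lazily costs N + 1 calls, while a join-based fetch costs a single one.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

public class NPlusOneDemo {
    // Hypothetical in-memory "database": zoo id -> animal names.
    // Every method call counts as one SQL query.
    static final class FakeDb {
        private final Map<Long, List<String>> animalsByZoo;
        private int queries = 0;

        FakeDb(Map<Long, List<String>> animalsByZoo) {
            this.animalsByZoo = animalsByZoo;
        }

        // Like "select ... from zoo"
        List<Long> selectZooIds() {
            queries++;
            return new ArrayList<>(animalsByZoo.keySet());
        }

        // Like "select ... from animal where zoo_id = ?"
        List<String> selectAnimals(long zooId) {
            queries++;
            return animalsByZoo.get(zooId);
        }

        // Like "select ... from zoo left join animal ..."
        Map<Long, List<String>> selectZoosWithAnimals() {
            queries++;
            return animalsByZoo;
        }

        int queries() {
            return queries;
        }
    }

    // Lazy loading: 1 query for the zoos + 1 query per zoo for its animals.
    static int lazyLoad(FakeDb db) {
        for (long zooId : db.selectZooIds()) {
            db.selectAnimals(zooId);
        }
        return db.queries();
    }

    // Fetch join: zoos and animals come back in a single query.
    static int joinFetch(FakeDb db) {
        db.selectZoosWithAnimals();
        return db.queries();
    }

    public static void main(String[] args) {
        var data = Map.of(1L, List.of("Lion"), 2L, List.of("Zebra", "Giraffe"));
        System.out.println("lazy loading queries: " + lazyLoad(new FakeDb(data))); // 3 (N = 2)
        System.out.println("join fetch queries: " + joinFetch(new FakeDb(data)));  // 1
    }
}
```

With two zoos the lazy variant issues three queries. The count grows linearly with the number of zoos, while the join variant stays at one.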

Tracking the N + 1 problem automatically

Of course, you can run tests, look through the logs, and count queries yourself to spot potential performance issues. However, this is both tedious and inefficient. Thankfully, there is a better approach.

There is a cool library called datasource-proxy. It provides a convenient API to wrap a javax.sql.DataSource with a proxy containing specific logic. For example, we can register callbacks that are invoked before and after query execution. Interestingly, the library also contains an out-of-the-box solution for counting executed queries. We're going to alter it a bit to serve our needs.
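The general idea of wrapping a JDBC object with a proxy that fires callbacks can be shown with the JDK's own java.lang.reflect.Proxy. The sketch below is a minimal illustration of that technique and is not datasource-proxy's actual implementation. QueryExecutor and QueryListener are hypothetical interfaces standing in for the real JDBC types and the library's listeners.

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Proxy;
import java.util.ArrayList;
import java.util.List;

public class ProxyIdeaDemo {
    // Hypothetical stand-in for a JDBC object that executes SQL.
    interface QueryExecutor {
        int execute(String sql);
    }

    // Hypothetical callbacks, similar in spirit to datasource-proxy's listeners.
    interface QueryListener {
        void beforeQuery(String sql);

        void afterQuery(String sql);
    }

    // Returns a dynamic proxy that notifies the listener around every execute() call
    // and delegates the actual work to the real target.
    static QueryExecutor wrap(QueryExecutor target, QueryListener listener) {
        InvocationHandler handler = (proxy, method, args) -> {
            boolean isQuery = method.getName().equals("execute");
            if (isQuery) {
                listener.beforeQuery((String) args[0]);
            }
            try {
                return method.invoke(target, args);
            } catch (InvocationTargetException e) {
                throw e.getCause(); // rethrow the target's own exception, not the reflection wrapper
            } finally {
                if (isQuery) {
                    listener.afterQuery((String) args[0]); // runs even if the query failed
                }
            }
        };
        return (QueryExecutor) Proxy.newProxyInstance(
                QueryExecutor.class.getClassLoader(),
                new Class<?>[] {QueryExecutor.class},
                handler);
    }

    public static void main(String[] args) {
        List<String> log = new ArrayList<>();
        QueryExecutor real = sql -> 1; // pretend every statement affects one row
        QueryExecutor proxied = wrap(real, new QueryListener() {
            @Override
            public void beforeQuery(String sql) { log.add("before: " + sql); }

            @Override
            public void afterQuery(String sql) { log.add("after: " + sql); }
        });
        proxied.execute("select 1");
        System.out.println(log);
    }
}
```

Because the proxy implements the same interface as the target, calling code can't tell the two apart. That is what makes this kind of wrapper transparent to Spring Data JPA and Hibernate.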

Query Count Service

Firstly, add the library to the dependencies:

implementation "net.ttddyy:datasource-proxy:1.8"

Now create the QueryCountService. It's a singleton that holds the current count of executed queries and allows resetting it. Look at the code snippet below.

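import lombok.experimental.UtilityClass;
import net.ttddyy.dsproxy.QueryCount;
import net.ttddyy.dsproxy.listener.SingleQueryCountHolder;

import java.util.Map;

import static java.util.Optional.ofNullable;

@UtilityClass
public class QueryCountService {
    // Shared holder that datasource-proxy updates on every executed statement
    static final SingleQueryCountHolder QUERY_COUNT_HOLDER = new SingleQueryCountHolder();

    // Resets the counters by replacing the current QueryCount with a fresh one.
    // put (not putIfAbsent) is required: the key always exists here, so putIfAbsent would be a no-op.
    public static void clear() {
        final var map = QUERY_COUNT_HOLDER.getQueryCountMap();
        map.put(keyName(map), new QueryCount());
    }

    // Returns the counters collected since the last clear()
    public static QueryCount get() {
        final var map = QUERY_COUNT_HOLDER.getQueryCountMap();
        return ofNullable(map.get(keyName(map))).orElseThrow();
    }

    // Assumes a single DataSource, i.e. exactly one entry in the map
    private static String keyName(Map<String, QueryCount> map) {
        if (map.size() == 1) {
            return map.entrySet()
                       .stream()
                       .findFirst()
                       .orElseThrow()
                       .getKey();
        }
        throw new IllegalArgumentException("Query counts map should consist of one key: " + map);
    }
}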

Here we assume there is a single DataSource in the application. That's why the keyName function throws an exception otherwise. However, supporting multiple data sources wouldn't change the code much.

The SingleQueryCountHolder stores all QueryCount objects in a regular ConcurrentHashMap that is shared by all threads.

In contrast, ThreadQueryCountHolder stores the values in a ThreadLocal, so each thread sees only its own counts. SingleQueryCountHolder is enough for our case.
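The difference between the two holders matters as soon as more than one thread runs queries. The sketch below mimics both storage strategies with plain JDK types. CountHolderDemo and its methods are hypothetical and don't reproduce datasource-proxy's classes. A query counted on a worker thread shows up in the shared map but not in the main thread's thread-local map.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicLong;

public class CountHolderDemo {
    // One map shared by every thread (the idea behind SingleQueryCountHolder).
    static final ConcurrentMap<String, AtomicLong> SHARED = new ConcurrentHashMap<>();

    // A separate map per thread (the idea behind ThreadQueryCountHolder).
    static final ThreadLocal<Map<String, AtomicLong>> PER_THREAD =
            ThreadLocal.withInitial(HashMap::new);

    // Records one select for the given data source in both holders.
    static void countSelect(String dataSourceName) {
        SHARED.computeIfAbsent(dataSourceName, name -> new AtomicLong()).incrementAndGet();
        PER_THREAD.get().computeIfAbsent(dataSourceName, name -> new AtomicLong()).incrementAndGet();
    }

    // Reads the count recorded by all threads.
    static long shared(String dataSourceName) {
        AtomicLong count = SHARED.get(dataSourceName);
        return count == null ? 0 : count.get();
    }

    // Reads the count recorded by the calling thread only.
    static long perThread(String dataSourceName) {
        AtomicLong count = PER_THREAD.get().get(dataSourceName);
        return count == null ? 0 : count.get();
    }

    public static void main(String[] args) throws InterruptedException {
        Thread worker = new Thread(() -> countSelect("dataSource"));
        worker.start();
        worker.join();
        System.out.println("shared: " + shared("dataSource"));        // 1
        System.out.println("per-thread: " + perThread("dataSource")); // 0, the query ran on another thread
    }
}
```

The flip side of a shared holder is that it also counts queries issued by unrelated threads. Running such tests in parallel within one JVM could therefore distort the numbers.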

The API provides two methods. The get method returns the current number of executed queries, while clear resets the count to zero.
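One detail worth checking in clear is the choice of map operation. In java.util.Map, putIfAbsent leaves an existing entry untouched, whereas put replaces it. Since keyName only succeeds when the key is already present, a reset has to use put. The sketch below demonstrates the difference with a hypothetical Counter class standing in for QueryCount.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

public class ResetDemo {
    // Hypothetical mutable counter standing in for datasource-proxy's QueryCount.
    static final class Counter {
        long selects;
    }

    // Builds a map whose single entry has already counted five selects.
    static Map<String, Counter> mapWithFiveSelects() {
        Map<String, Counter> map = new ConcurrentHashMap<>();
        Counter counter = new Counter();
        counter.selects = 5;
        map.put("dataSource", counter);
        return map;
    }

    public static void main(String[] args) {
        Map<String, Counter> a = mapWithFiveSelects();
        a.putIfAbsent("dataSource", new Counter()); // key exists, so nothing changes
        System.out.println(a.get("dataSource").selects); // 5

        Map<String, Counter> b = mapWithFiveSelects();
        b.put("dataSource", new Counter()); // replaces the entry with a fresh counter
        System.out.println(b.get("dataSource").selects); // 0
    }
}
```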

BeanPostProcessor and DataSource proxy

Now we need to hook QUERY_COUNT_HOLDER into the DataSource so that it collects the query statistics. The BeanPostProcessor interface comes in handy here. Look at the code example below.

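import net.ttddyy.dsproxy.support.ProxyDataSourceBuilder;
import org.springframework.beans.factory.config.BeanPostProcessor;
import org.springframework.boot.test.context.TestComponent;

import javax.sql.DataSource;

@TestComponent
public class DatasourceProxyBeanPostProcessor implements BeanPostProcessor {
    // Replaces every DataSource bean with a proxy that counts executed queries
    @Override
    public Object postProcessAfterInitialization(Object bean, String beanName) {
        if (bean instanceof DataSource dataSource) {
            return ProxyDataSourceBuilder.create(dataSource)
                       // QueryCountService is in the same package, so its package-private holder is visible
                       .countQuery(QueryCountService.QUERY_COUNT_HOLDER)
                       .build();
        }
        return bean;
    }
}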

I mark the class with the @TestComponent annotation and put it in the src/test directory because I don't need to count queries outside of the test scope.

As you can see, the idea is simple. If a bean is a DataSource, wrap it with ProxyDataSourceBuilder and pass QUERY_COUNT_HOLDER as the QueryCountStrategy.
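If BeanPostProcessor is new to you, the sketch below imitates its contract in plain Java. The container passes every bean through the processors and registers whatever they return, so returning a wrapper effectively swaps the bean. All names here (Source, CountingSource, wrapSources, initialize) are hypothetical. This is a simplified model, not Spring's actual container code.

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;

public class PostProcessingDemo {
    // Hypothetical stand-in for javax.sql.DataSource.
    interface Source {
        String describe();
    }

    static final class RealSource implements Source {
        public String describe() { return "real"; }
    }

    // Hypothetical wrapper, playing the role of the datasource-proxy proxy.
    static final class CountingSource implements Source {
        private final Source target;

        CountingSource(Source target) { this.target = target; }

        public String describe() { return "counting(" + target.describe() + ")"; }
    }

    // Same idea as DatasourceProxyBeanPostProcessor: wrap matching beans, pass the rest through.
    static Object wrapSources(Object bean, String beanName) {
        return bean instanceof Source source ? new CountingSource(source) : bean;
    }

    // Mimics the BeanPostProcessor contract: the object returned by the last
    // processor, not the original bean, is what ends up in the registry.
    static Map<String, Object> initialize(Map<String, Object> beans,
                                          List<BiFunction<Object, String, Object>> processors) {
        Map<String, Object> registry = new LinkedHashMap<>();
        for (Map.Entry<String, Object> entry : beans.entrySet()) {
            Object current = entry.getValue();
            for (BiFunction<Object, String, Object> processor : processors) {
                current = processor.apply(current, entry.getKey());
            }
            registry.put(entry.getKey(), current);
        }
        return registry;
    }

    public static void main(String[] args) {
        BiFunction<Object, String, Object> processor = PostProcessingDemo::wrapSources;
        Map<String, Object> registry = initialize(
                Map.of("dataSource", new RealSource(), "greeting", "hello"),
                List.of(processor));
        System.out.println(((Source) registry.get("dataSource")).describe()); // counting(real)
        System.out.println(registry.get("greeting"));                          // hello
    }
}
```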

Custom assertions

Finally, we want to assert the number of queries executed by a specific method. Look at the code snippet with custom assertions below:

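import lombok.experimental.UtilityClass;

import java.util.function.Supplier;

import static org.junit.jupiter.api.Assertions.assertAll;
import static org.junit.jupiter.api.Assertions.assertEquals;

@UtilityClass
public class QueryCountAssertions {
    // Runs the supplier and checks the queries it executed against the expectation.
    // A negative expected value disables the check for that query type.
    // Expectation is a static nested class of QueryCountAssertions (declared below).
    public static <T> T assertQueryCount(Supplier<T> supplier, Expectation expectation) {
        QueryCountService.clear();
        final var result = supplier.get();
        final var queryCount = QueryCountService.get();
        // assertAll reports every mismatch instead of stopping at the first one
        assertAll(
            () -> {
                if (expectation.selects >= 0) {
                    assertEquals(expectation.selects, queryCount.getSelect(), "Unexpected selects count");
                }
            },
            () -> {
                if (expectation.inserts >= 0) {
                    assertEquals(expectation.inserts, queryCount.getInsert(), "Unexpected inserts count");
                }
            },
            () -> {
                if (expectation.deletes >= 0) {
                    assertEquals(expectation.deletes, queryCount.getDelete(), "Unexpected deletes count");
                }
            },
            () -> {
                if (expectation.updates >= 0) {
                    assertEquals(expectation.updates, queryCount.getUpdate(), "Unexpected updates count");
                }
            }
        );
        return result;
    }
}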

The algorithm is straightforward:

  1. Reset the current query count to zero.
  2. Execute the provided lambda.
  3. Assert the query count against the given Expectation object.
  4. If everything passes, return the result of the execution.

You may also have noticed an additional condition. If the expected count for a query type is negative, the corresponding assertion is skipped. That's convenient when you don't care about the counts of other query types.
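The same four steps can be sketched without JUnit or datasource-proxy. The sketch also covers the skip-on-negative rule and the way assertAll reports all failures at once. The static counters and assertCounts below are hypothetical stand-ins for QueryCountService and assertQueryCount, and only selects and inserts are modeled.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Supplier;

public class QueryCountCheckDemo {
    // Hypothetical counters standing in for QueryCountService / QueryCount.
    static long selects = 0;
    static long inserts = 0;

    static final int ANY = -1; // negative expectation = "don't check this query type"

    static <T> T assertCounts(Supplier<T> action, int expectedSelects, int expectedInserts) {
        // 1. Reset the counters.
        selects = 0;
        inserts = 0;
        // 2. Run the code under test.
        T result = action.get();
        // 3. Check every expected type and report all mismatches together (like assertAll).
        List<String> failures = new ArrayList<>();
        check(failures, "selects", expectedSelects, selects);
        check(failures, "inserts", expectedInserts, inserts);
        if (!failures.isEmpty()) {
            throw new AssertionError(String.join("; ", failures));
        }
        // 4. Hand the result back so the test can keep asserting on it.
        return result;
    }

    // Records a failure only when the type is checked (expected >= 0) and the counts differ.
    private static void check(List<String> failures, String type, int expected, long actual) {
        if (expected >= 0 && expected != actual) {
            failures.add("Unexpected " + type + " count: expected " + expected + " but was " + actual);
        }
    }

    public static void main(String[] args) {
        // Simulated N + 1: three selects where one was expected.
        try {
            assertCounts(() -> { selects += 3; return "zoos"; }, 1, ANY);
        } catch (AssertionError e) {
            System.out.println(e.getMessage());
        }
        // Inserts are ignored (ANY), so only the select count is checked.
        String zoos = assertCounts(() -> { selects += 1; inserts += 5; return "zoos"; }, 1, ANY);
        System.out.println(zoos);
    }
}
```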

The Expectation class is just a regular data structure, declared as a static nested class of QueryCountAssertions. Look at its declaration below:

@With
@AllArgsConstructor
@NoArgsConstructor
public static class Expectation {
    private int selects = -1;
    private int inserts = -1;
    private int deletes = -1;
    private int updates = -1;

    public static Expectation ofSelects(int selects) {
        return new Expectation().withSelects(selects);
    }

    public static Expectation ofInserts(int inserts) {
        return new Expectation().withInserts(inserts);
    }

    public static Expectation ofDeletes(int deletes) {
        return new Expectation().withDeletes(deletes);
    }

    public static Expectation ofUpdates(int updates) {
        return new Expectation().withUpdates(updates);
    }
}
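If you'd rather avoid Lombok, a Java record (Java 16+) can play the same role. The sketch below is a hypothetical Lombok-free variant, not the article's code. Each with... method returns a modified copy, which is roughly what @With generates.

```java
public class ExpectationDemo {
    // Lombok-free sketch of Expectation as a record; -1 means "don't check this query type".
    record Expectation(int selects, int inserts, int deletes, int updates) {
        static final int ANY = -1;

        // Starting point with every check disabled.
        static Expectation any() {
            return new Expectation(ANY, ANY, ANY, ANY);
        }

        static Expectation ofSelects(int selects) {
            return any().withSelects(selects);
        }

        // Each wither returns a copy with one field changed; the original stays untouched.
        Expectation withSelects(int value) { return new Expectation(value, inserts, deletes, updates); }

        Expectation withInserts(int value) { return new Expectation(selects, value, deletes, updates); }

        Expectation withDeletes(int value) { return new Expectation(selects, inserts, value, updates); }

        Expectation withUpdates(int value) { return new Expectation(selects, inserts, deletes, value); }
    }

    public static void main(String[] args) {
        Expectation expectation = Expectation.ofSelects(1).withInserts(0);
        System.out.println(expectation); // Expectation[selects=1, inserts=0, deletes=-1, updates=-1]
    }
}
```

Chaining such as ofSelects(1).withInserts(0) works with the Lombok version too, because @With on the class generates a wither for every field.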

The final example

Let's see how it works. First, I add the query assertion to the previous test case with the N + 1 problem. Look at the code block below:

final var allZoos = assertQueryCount(
    () -> zooService.findAllZoos(),
    ofSelects(1)
);

Don't forget to register DatasourceProxyBeanPostProcessor as a Spring bean in your tests, for example with @Import({ZooService.class, DatasourceProxyBeanPostProcessor.class}).

If we rerun the test, we'll get the output below.

Multiple Failures (1 failure)
    org.opentest4j.AssertionFailedError: Unexpected selects count ==> expected: <1> but was: <3>
Expected :1
Actual   :3

So, the assertion does work. We managed to track the N + 1 problem automatically. Time to replace the regular select with JOIN FETCH. Look at the code snippet below.

public interface ZooRepository extends JpaRepository<Zoo, Long> {
    @Query("FROM Zoo z LEFT JOIN FETCH z.animals")
    List<Zoo> findAllWithAnimalsJoined();
}

@Service
@RequiredArgsConstructor
public class ZooService {
    private final ZooRepository zooRepository;

    @Transactional(readOnly = true)
    public List<ZooResponse> findAllZoos() {
        final var zoos = zooRepository.findAllWithAnimalsJoined();
        return zoos.stream()
                   .map(ZooResponse::new)
                   .toList();
    }
}

Let's run the test again and check out the result:

Test result with JOIN FETCH

This means the assertion tracks N + 1 problems correctly. It also passes when the number of queries equals the expected one. Great!
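Why can a single query replace N + 1 of them? A LEFT JOIN returns one row per (zoo, animal) pair, and the persistence provider folds those rows back into zoos with their animals. The sketch below illustrates that grouping step in plain Java. Row and ZooView are hypothetical types, and the code is not how Hibernate works internally.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class JoinRowsDemo {
    // One row of "zoo LEFT JOIN animal"; animalName is null for a zoo without animals.
    record Row(long zooId, String zooName, String animalName) {}

    // Hypothetical result type: one zoo with all of its animals.
    record ZooView(long id, String name, List<String> animals) {}

    // Folds the flat join rows into one ZooView per zoo, keeping the row order.
    static List<ZooView> group(List<Row> rows) {
        Map<Long, ZooView> byId = new LinkedHashMap<>();
        for (Row row : rows) {
            ZooView zoo = byId.computeIfAbsent(row.zooId(),
                    id -> new ZooView(id, row.zooName(), new ArrayList<>()));
            if (row.animalName() != null) {
                zoo.animals().add(row.animalName());
            }
        }
        return new ArrayList<>(byId.values());
    }

    public static void main(String[] args) {
        List<Row> rows = List.of(
                new Row(1, "City Zoo", "Lion"),
                new Row(1, "City Zoo", "Tiger"),
                new Row(2, "Safari Park", null));
        group(rows).forEach(System.out::println);
    }
}
```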

Conclusion

As a matter of fact, it is possible to prevent N + 1 problems with regular tests. I think that's a great opportunity to put guards on the parts of the code that are crucial from a performance perspective.

That's all I wanted to tell you about dealing with the N + 1 problem automatically. If you have any questions or suggestions, leave your comments below. Also, if you like this piece, share it with your friends and colleagues. Perhaps they'll find it beneficial too. Thanks for reading!

Resources

  1. The repository with code examples
  2. My article 'Spring Boot Testing — Data and Services'
  3. Testcontainers
  4. Datasource proxy library
  5. The BeanPostProcessor interface example
