Note. I’ve received a complaint that this site has performance issues, so here is a mirror hosted on my own blog: https://blog.plummmm.com/posts/repr-int-sum-eight-squares.html

Introduction

We’re interested in the following computational problem: given $n \in \mathbb{N}$, in how many ways can $n$ be written as a sum of eight squares? That is, how many tuples $(x_1, \dots, x_8) \in \mathbb{Z}^8$ satisfy $n = x_1^2 + \dots + x_8^2$?

Our basic computer programming knowledge tells us that we can design a naive algorithm as follows:

n = int(input())
ans = 0
for x1 in range(-n, n+1):
	for x2 in range(-n, n+1):
		for x3 in range(-n, n+1):
			for x4 in range(-n, n+1):
				for x5 in range(-n, n+1):
					for x6 in range(-n, n+1):
						for x7 in range(-n, n+1):
							for x8 in range(-n, n+1):
								if x1**2 + x2**2 + x3**2 + x4**2 + x5**2 + x6**2 + x7**2 + x8**2 == n:
									ans += 1
print(ans)

Of course, this takes time $O(n^8)$, which is not very good. An obvious improvement is to observe that $x^2 \geq 0$ for all $x \in \mathbb{Z}$, so for a tuple to work, no entry can exceed $\sqrt{n}$ in absolute value: otherwise (say, if $|x_1| > \sqrt{n}$) we would have $x_1^2 + \dots + x_8^2 > n + x_2^2 + \dots + x_8^2 \geq n$. This lets us shrink the range of each loop to size $O(\sqrt{n})$.

n = int(input())
ans = 0
k = int(n**0.5)+1
for x1 in range(-k, k+1):
	for x2 in range(-k, k+1):
		for x3 in range(-k, k+1):
			for x4 in range(-k, k+1):
				for x5 in range(-k, k+1):
					for x6 in range(-k, k+1):
						for x7 in range(-k, k+1):
							for x8 in range(-k, k+1):
								if x1**2 + x2**2 + x3**2 + x4**2 + x5**2 + x6**2 + x7**2 + x8**2 == n:
									ans += 1
print(ans)

This takes time $O(n^4)$, a bit better, huh?

Obviously, in the eyes of an algorithmist, we can do better. A direct approach for a problem of this form is the meet-in-the-middle trick: instead of solving $x_1^2 + \dots + x_8^2 = n$ directly, we solve $x_1^2 + x_2^2 + x_3^2 + x_4^2 = m$ and $x_5^2 + x_6^2 + x_7^2 + x_8^2 = n-m$ separately, multiply the two solution counts, and sum the products over all $m \in \{0, \dots, n\}$. Let us implement it.

n = int(input())

def count_sol(m):
	k = int(m**0.5)+1
	local_ans = 0
	for a in range(-k, k+1):
		for b in range(-k, k+1):
			for c in range(-k, k+1):
				for d in range(-k, k+1):
					if a**2 + b**2 + c**2 + d**2 == m:
						local_ans += 1
	return local_ans

ans = 0
for m in range(n+1):
	ans += count_sol(m) * count_sol(n-m)
print(ans)

The count_sol subroutine runs in time $O((\sqrt{m})^4) = O(m^2)$. We call it $2(n+1)$ times, each time with an argument of at most $n$, so the total time is $2(n+1) \cdot O(n^2) = O(n^3)$.
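As a side remark, most of the work above is repeated: every call to count_sol re-enumerates four-tuples from scratch. One possible variant (a sketch, not part of the original derivation) tabulates the four-square counts for all $m \leq n$ in a single pass and then combines them. The names `four_square_table` and `count_eight_squares` are made up for this illustration, and `math.isqrt` assumes Python 3.8 or later.

```python
from math import isqrt

def four_square_table(n):
    """table[m] = number of (a, b, c, d) in Z^4 with a^2 + b^2 + c^2 + d^2 == m, for 0 <= m <= n."""
    k = isqrt(n)
    # Squares of every candidate entry -k..k (each positive square appears twice, once per sign).
    squares = [x * x for x in range(-k, k + 1)]
    table = [0] * (n + 1)
    for a in squares:
        for b in squares:
            for c in squares:
                for d in squares:
                    s = a + b + c + d
                    if s <= n:
                        table[s] += 1
    return table

def count_eight_squares(n):
    """Meet in the middle: pair a four-tuple summing to m with one summing to n - m."""
    r4 = four_square_table(n)
    return sum(r4[m] * r4[n - m] for m in range(n + 1))

print(count_eight_squares(1))  # 16: one entry is +1 or -1, the other seven are 0
```

Under these assumptions the table costs $O(n^2)$ to build and the final sum costs $O(n)$, so this variant already sits between the $O(n^3)$ version above and what follows.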

Can we do better? Still yes. Skilled algorithmists might even come up with this formulation immediately (note: I’m not that good though; I posed this problem to a bunch of friends and they immediately gave me the following formulation). We formulate this as a knapsack-style problem. Define $\mathrm{dp}_{i,j}$ to be the quantity $\#\{(x_1, \dots, x_i) \in \mathbb{Z}^i \colon \sum_{k=1}^i x_k^2 = j\}$, and look at what we can do about its recurrence.

In order to know the quantity $\mathrm{dp}_{i,j}$, we can enumerate all possible $x_i$’s and observe that

$$ \begin{aligned} \{(x_1, \dots, x_i) \in \mathbb{Z}^i \colon \sum_{k=1}^i x_k^2 = j\} &= \bigcup_{0 \leq x_i \leq \sqrt{j}} \{(x_1, \dots, x_{i-1}) \in \mathbb{Z}^{i-1} \colon \sum_{k=1}^{i} x_k^2 = j\} \times \{(x_i), (-x_i)\} \\ &= \bigcup_{0 \leq x_i \leq \sqrt{j}} \{(x_1, \dots, x_{i-1}) \in \mathbb{Z}^{i-1} \colon \sum_{k=1}^{i-1} x_k^2 = j-x_i^2\} \times \{(x_i), (-x_i)\}. \end{aligned} $$

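Observe that the union is disjoint, so the cardinality of the whole is the sum of the cardinalities of the cases. For $x_i = 0$ the set $\{(x_i), (-x_i)\}$ has a single element, while for $x_i \geq 1$ it has two, so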

$$ \mathrm{dp}_{i,j} = \sum_{1 \leq a \leq \sqrt{j}} 2\,\mathrm{dp}_{i-1, j-a^2} + \mathrm{dp}_{i-1,j}. $$
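To see the recurrence in action, here is a small worked instance. It assumes the natural base case $\mathrm{dp}_{0,0} = 1$ and $\mathrm{dp}_{0,j} = 0$ for $j > 0$ (the empty tuple represents only $0$):

$$ \begin{aligned} \mathrm{dp}_{1,0} &= \mathrm{dp}_{0,0} = 1, \\ \mathrm{dp}_{1,1} &= 2\,\mathrm{dp}_{0,0} + \mathrm{dp}_{0,1} = 2, \\ \mathrm{dp}_{2,1} &= 2\,\mathrm{dp}_{1,0} + \mathrm{dp}_{1,1} = 2 + 2 = 4. \end{aligned} $$

The last value matches a direct count: the pairs with $x_1^2 + x_2^2 = 1$ are exactly $(\pm 1, 0)$ and $(0, \pm 1)$.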

With this formula, we can implement our dynamic programming algorithm as follows.

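from math import isqrt

n = int(input())
# dp[i][j] = number of tuples (x_1, ..., x_i) in Z^i with x_1^2 + ... + x_i^2 == j
dp = [[0 for j in range(n+1)] for i in range(9)]
dp[0][0] = 1  # the empty tuple represents 0
for i in range(1, 9):
	for j in range(0, n+1):
		dp[i][j] += dp[i-1][j]  # case x_i = 0
		# case x_i = a or x_i = -a; isqrt(j) is exact, so a*a <= j always holds
		for a in range(1, isqrt(j)+1):
			dp[i][j] += 2*dp[i-1][j-a**2]
print(dp[8][n])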
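Before trusting a recurrence like this, it is worth checking it against brute force on small inputs. The following is a minimal sketch of such a check. The function names `count_dp` and `count_brute` and the `parts` parameter are introduced here for illustration only, and the DP keeps just one row at a time instead of the full table.

```python
from itertools import product
from math import isqrt

def count_dp(n, parts=8):
    """Count tuples in Z^parts whose squares sum to n, using the recurrence above."""
    dp = [1] + [0] * n  # zero variables: only the empty sum 0 is representable
    for _ in range(parts):
        new = [0] * (n + 1)
        for j in range(n + 1):
            total = dp[j]  # the new variable is 0
            for a in range(1, isqrt(j) + 1):
                total += 2 * dp[j - a * a]  # the new variable is a or -a
            new[j] = total
        dp = new
    return dp[n]

def count_brute(n, parts=8):
    """Direct enumeration over [-isqrt(n), isqrt(n)]^parts; only feasible for tiny n."""
    k = isqrt(n)
    return sum(
        1
        for xs in product(range(-k, k + 1), repeat=parts)
        if sum(x * x for x in xs) == n
    )

# Cross-check the two counts on inputs small enough for brute force.
for n in range(5):
    assert count_dp(n) == count_brute(n)

print([count_dp(n) for n in range(5)])  # [1, 16, 112, 448, 1136]
```

Keeping a single row reduces memory from nine rows to two, while the running time of the dynamic programming part is unchanged.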

The dynamic programming algorithm above takes time $O(n^{3/2})$ (eight layers, $n+1$ values of $j$ per layer, and $O(\sqrt{j})$ terms per value), which is already a big improvement over what came before. Is this the end? Can we do better? An algorithmist might give up at this point, but I refuse to give up just yet! Can we use some mathematical magic to improve this even more? It turns out that we can!

In this article, I’m going to invite you into the fascinating (and fairly deep) theory of modular forms, not just as an algorithmist, but as an algorithmic number theorist.