/*
 * solve - sum of remainders: returns sum_{i=1}^{n} (k % i).
 *
 * Divisors i that share the same quotient q = k / i form a contiguous run.
 * On such a run, k % i == k - q*i is an arithmetic sequence in i, so the
 * whole run is summed in closed form. This takes O(sqrt(k)) iterations
 * instead of O(n).
 *
 * n <= 0 yields the empty sum 0. k == 0 yields 0, since every term is 0.
 * Precondition: k >= 0. Negative k is not supported, because C's % truncates
 * toward zero and the grouping below assumes non-negative remainders.
 * The caller must ensure the true result fits in ll.
 */
ll solve(ll n, ll k)
{
    ll sum = 0;

    /* Empty range, or all terms are zero. This also guards k / n below. */
    if (n <= 0 || k == 0)
        return 0;

    /* For every i > k the remainder is k itself. */
    if (n > k) {
        sum += k * (n - k);
        n = k;
    }

    ll q = k / n; /* quotient shared by the run ending at i = n */
    while (n > 1) {
        /* Every i in (next, n] has quotient q when q is current. */
        ll next = k / (q + 1);
        if (next == n) {
            /*
             * q is stale: k / n > q. Add this single term and
             * resynchronise the quotient.
             */
            sum += k % n;
            n--;
            q = k / n;
            continue;
        }
        ll low = k % n;           /* remainder at i = n (smallest of the run) */
        ll high = k % (next + 1); /* remainder at i = next + 1 (largest) */
        ll count = n - next;
        ll pair = low + high;
        /*
         * pair * count is even because the run sums to an integer. Halve
         * the even factor first so the intermediate product cannot overflow
         * before the final sum does.
         */
        sum += (count % 2 == 0) ? pair * (count / 2) : (pair / 2) * count;
        n = next;
        q++;
    }
    /* The i = 1 term is always k % 1 == 0, so stopping at n == 1 is exact. */
    return sum;
}