All Articles

A Weird Bug Caused by Optimizing Integer Division with Assembly

assembly

Recently, I encountered a weird bug in my code. I was trying to optimize a piece of C code with assembly; specifically, I wanted to speed up a division operation. I thought that using assembly would be a good idea, but it turned out to be a disaster: it took me more than 3 hours to figure out what was going on. Even after I fixed the bug, I didn’t get a significant performance improvement.

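The code was simple. On x86, the division instruction (idiv) produces the quotient and the remainder at the same time: the quotient ends up in %eax and the remainder in %edx. In C, however, we have to write two separate operations, n % p and n / p, to get both values, and the unoptimized compiler output below really does execute idivl twice per loop iteration. I thought inline assembly could remove the redundant division. (With optimization enabled, GCC can usually merge a / and a % of the same operands into a single idivl by itself, which may explain why the final gain was small.) The code before optimization was like this: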

// prog_purec.c
#include <stdio.h>

// Define the problem scale
// MAXN: sieve size; mk[] covers 0 .. MAXN-1, so every input n must lie in [0, MAXN).
// MAXP: capacity of prime[]; must be at least the number of primes below MAXN,
//       because init() appends every prime below MAXN without a bounds check.
// MAXR: init() crosses off multiples only for primes below MAXR (about sqrt(MAXN)).
//       Every composite below MAXN has a prime factor of at most 2236, and
//       2236 = 2^2 * 13 * 43 is not prime, so primes below MAXR are enough.
#define MAXN 5000004
#define MAXP 350000
#define MAXR 2236

char mk[MAXN];         // Sieve flags: for 2 <= k < MAXN, mk[k] == 1 iff k is composite
int prime[MAXP], pnum; // The pnum primes below MAXN, in increasing order (filled by init())
int c[32];             // c[k] = 1^3 + 2^3 + ... + (k+1)^3 (filled by init())

void init()
{
    int i, j;
    for (i = 0; i < 32; i++)
        c[i] = (i + 2) * (i + 2) * (i + 1) * (i + 1) >> 2; // Sum of i^3
    for (i = 2; i < MAXN; i++)
        if (!mk[i])
        {
            prime[pnum++] = i;
            if (i < MAXR)
                for (j = i * i; j < MAXN; j += i)
                    mk[j] = 1;
        }
}

int main()
{
    int i, j, ncase, n, res;
    init();
    scanf("%d", &ncase);
    while (ncase--)
    {
        scanf("%d", &n);
        res = 1;
        for (i = 0; i < pnum; i++)
            if (!mk[n])
            {
                if (n > 1)
                    res *= c[1];
                break;
            }
            else
            {
                j = 0;
                while (!(n % prime[i]))
                {
                    n /= prime[i];
                    j++;
                }
                if (j > 0)
                    res *= c[j];
            }
        printf("%d\n", res);
    }
    return 0;
}
# prog_purec.s: GCC 4.1.2 output for prog_purec.c (32-bit x86, AT&T syntax).
# Every variable lives in a stack slot and is reloaded on each use, so this
# was apparently compiled without optimization (exact flags are not shown).
.file    "prog_purec.c"
.text
.globl init
.type    init, @function
# void init(void) -- stack slots: -8(%ebp) = i, -4(%ebp) = j
init:
    pushl    %ebp
    movl    %esp, %ebp
    subl    $16, %esp
    # for (i = 0; i <= 31; i++) c[i] = (i+2)*(i+2)*(i+1)*(i+1) >> 2
    movl    $0, -8(%ebp)
    jmp    .L2
.L3:
    movl    -8(%ebp), %ecx
    movl    -8(%ebp), %edx
    addl    $2, %edx
    movl    -8(%ebp), %eax
    addl    $2, %eax
    imull    %eax, %edx
    movl    -8(%ebp), %eax
    incl    %eax
    imull    %eax, %edx
    movl    -8(%ebp), %eax
    incl    %eax
    imull    %edx, %eax
    sarl    $2, %eax
    movl    %eax, c(,%ecx,4)
    incl    -8(%ebp)
.L2:
    cmpl    $31, -8(%ebp)
    jle    .L3
    # for (i = 2; i <= 5000003; i++), i.e. i < MAXN
    movl    $2, -8(%ebp)
    jmp    .L5
.L6:
    # if (!mk[i]) prime[pnum++] = i
    movl    -8(%ebp), %eax
    movzbl    mk(%eax), %eax
    testb    %al, %al
    jne    .L7
    movl    pnum, %eax
    movl    -8(%ebp), %edx
    movl    %edx, prime(,%eax,4)
    incl    %eax
    movl    %eax, pnum
    # if (i <= 2235), i.e. i < MAXR: for (j = i*i; j <= 5000003; j += i) mk[j] = 1
    cmpl    $2235, -8(%ebp)
    jg    .L7
    movl    -8(%ebp), %eax
    imull    -8(%ebp), %eax
    movl    %eax, -4(%ebp)
    jmp    .L10
.L11:
    movl    -4(%ebp), %eax
    movb    $1, mk(%eax)
    movl    -8(%ebp), %eax
    addl    %eax, -4(%ebp)
.L10:
    cmpl    $5000003, -4(%ebp)
    jle    .L11
.L7:
    incl    -8(%ebp)
.L5:
    cmpl    $5000003, -8(%ebp)
    jle    .L6
    leave
    ret
    .size    init, .-init
    .section    .rodata
# Format strings: .LC0 for scanf, .LC1 for printf.
.LC0:
    .string    "%d"
.LC1:
    .string    "%d\n"
    .text
.globl main
    .type    main, @function
# int main(void)
# Stack slots: -20 = ncase, -24 = n, -8 = res, -16 = i, -12 = j,
#              -40 = scratch copy of prime[i] used as the idivl divisor.
main:
    # Align %esp down to 16 bytes and build the frame. %ecx (original %esp + 4)
    # is saved at -4(%ebp) and used by the epilogue to restore %esp.
    leal    4(%esp), %ecx
    andl    $-16, %esp
    pushl    -4(%ecx)
    pushl    %ebp
    movl    %esp, %ebp
    pushl    %ecx
    subl    $52, %esp
    call    init
    # scanf("%d", &ncase)
    leal    -20(%ebp), %eax
    movl    %eax, 4(%esp)
    movl    $.LC0, (%esp)
    call    scanf
    jmp    .L15
.L16:
    # Case body: scanf("%d", &n); res = 1; i = 0
    leal    -24(%ebp), %eax
    movl    %eax, 4(%esp)
    movl    $.LC0, (%esp)
    call    scanf
    movl    $1, -8(%ebp)
    movl    $0, -16(%ebp)
    jmp    .L17
.L18:
    # if (!mk[n]) { if (n > 1) res *= c[1]; break; }   (c+4 is c[1])
    movl    -24(%ebp), %eax
    movzbl    mk(%eax), %eax
    testb    %al, %al
    jne    .L19
    movl    -24(%ebp), %eax
    cmpl    $1, %eax
    jle    .L23
    movl    c+4, %eax
    movl    -8(%ebp), %edx
    imull    %edx, %eax
    movl    %eax, -8(%ebp)
    jmp    .L23
.L19:
    # else: j = 0, then enter the while loop at its test (.L24)
    movl    $0, -12(%ebp)
    jmp    .L24
.L25:
    # while body: n /= prime[i]; j++
    # This second cltd/idivl recomputes the quotient that the loop test
    # below has just computed and discarded.
    movl    -24(%ebp), %edx
    movl    -16(%ebp), %eax
    movl    prime(,%eax,4), %eax
    movl    %eax, -40(%ebp)
    movl    %edx, %eax
    cltd
    idivl    -40(%ebp)
    movl    %eax, -40(%ebp)
    movl    -40(%ebp), %eax
    movl    %eax, -24(%ebp)
    incl    -12(%ebp)
.L24:
    # while test: n % prime[i] == 0 ?  The remainder (%edx) is copied into
    # %eax for the test, overwriting the quotient left there by idivl.
    movl    -24(%ebp), %edx
    movl    -16(%ebp), %eax
    movl    prime(,%eax,4), %eax
    movl    %eax, -40(%ebp)
    movl    %edx, %eax
    cltd
    idivl    -40(%ebp)
    movl    %edx, %eax
    testl    %eax, %eax
    je    .L25
    # if (j > 0) res *= c[j]
    cmpl    $0, -12(%ebp)
    jle    .L27
    movl    -12(%ebp), %eax
    movl    c(,%eax,4), %edx
    movl    -8(%ebp), %eax
    imull    %edx, %eax
    movl    %eax, -8(%ebp)
.L27:
    # i++
    incl    -16(%ebp)
.L17:
    # for-loop test: i < pnum
    movl    pnum, %eax
    cmpl    %eax, -16(%ebp)
    jl    .L18
.L23:
    # printf("%d\n", res)
    movl    -8(%ebp), %eax
    movl    %eax, 4(%esp)
    movl    $.LC1, (%esp)
    call    printf
.L15:
    # while (ncase--): store ncase - 1 and loop unless the new value is -1
    movl    -20(%ebp), %eax
    decl    %eax
    movl    %eax, -20(%ebp)
    movl    -20(%ebp), %eax
    cmpl    $-1, %eax
    jne    .L16
    # return 0; tear down the frame and undo the stack realignment
    movl    $0, %eax
    addl    $52, %esp
    popl    %ecx
    popl    %ebp
    leal    -4(%ecx), %esp
    ret
    .size    main, .-main
# Common (zero-initialized) globals: mk (5000004 bytes), prime (350000 ints),
# pnum, and c (32 ints).
    .comm    mk,5000004,32
    .comm    prime,1400000,32
    .comm    pnum,4,4
    .comm    c,128,32
    .ident    "GCC: (GNU) 4.1.2 20061115 (prerelease) (Debian 4.1.1-21)"
    .section    .note.GNU-stack,"",@progbits

The code of my first trial of optimization was like:

// prog_asm.c
#include <stdio.h>

// Define the problem scale
// MAXN: sieve size; mk[] covers 0 .. MAXN-1, so every input n must lie in [0, MAXN).
// MAXP: capacity of prime[]; must be at least the number of primes below MAXN,
//       because init() appends every prime below MAXN without a bounds check.
// MAXR: init() crosses off multiples only for primes below MAXR (about sqrt(MAXN)).
//       Every composite below MAXN has a prime factor of at most 2236, and
//       2236 = 2^2 * 13 * 43 is not prime, so primes below MAXR are enough.
#define MAXN 5000004
#define MAXP 350000
#define MAXR 2236

char mk[MAXN];         // Sieve flags: for 2 <= k < MAXN, mk[k] == 1 iff k is composite
int prime[MAXP], pnum; // The pnum primes below MAXN, in increasing order (filled by init())
int c[32];             // c[k] = 1^3 + 2^3 + ... + (k+1)^3 (filled by init())

void init()
{
    int i, j;
    for (i = 0; i < 32; i++)
        c[i] = (i + 2) * (i + 2) * (i + 1) * (i + 1) >> 2;
    for (i = 2; i < MAXN; i++)
        if (!mk[i])
        {
            prime[pnum++] = i;
            if (i < MAXR)
                for (j = i * i; j < MAXN; j += i)
                    mk[j] = 1;
        }
}

// First (broken) optimization attempt. Identical to the pure C version except
// that the inner loop body tries to reuse the quotient of the division done by
// the loop condition instead of dividing a second time.
int main()
{
    int i, j, ncase, n, res;
    init();
    scanf("%d", &ncase);
    while (ncase--)
    {
        scanf("%d", &n);
        res = 1;
        for (i = 0; i < pnum; i++)
            if (!mk[n])
            {
                if (n > 1)
                    res *= c[1];
                break;
            }
            else
            {
                j = 0;
                while (!(n % prime[i]))
                {
                    // BUG (kept on purpose to illustrate the article): intended
                    // as "n = quotient left in %eax by the idivl of the loop test".
                    // This basic asm has no operands, so GCC knows neither that it
                    // reads %eax nor that it writes n, and it hard-codes n's stack
                    // slot (-24(%ebp)) from the generated listing below. In that
                    // listing, movl %edx, %eax runs before this statement, so n
                    // receives the remainder (0), not the quotient.
                    asm("movl    %eax, -24(%ebp)");
                    j++;
                }
                if (j > 0)
                    res *= c[j];
            }
        printf("%d\n", res);
    }
    return 0;
}
# prog_asm.s: GCC 4.1.2 output for prog_asm.c (32-bit x86, AT&T syntax).
# Apart from the .file name, it differs from prog_purec.s only in the .L25
# loop body of main; init below is the same code.
.file    "prog_asm.c"
.text
.globl init
.type    init, @function
# void init(void) -- stack slots: -8(%ebp) = i, -4(%ebp) = j
init:
    pushl    %ebp
    movl    %esp, %ebp
    subl    $16, %esp
    # for (i = 0; i <= 31; i++) c[i] = (i+2)*(i+2)*(i+1)*(i+1) >> 2
    movl    $0, -8(%ebp)
    jmp    .L2
.L3:
    movl    -8(%ebp), %ecx
    movl    -8(%ebp), %edx
    addl    $2, %edx
    movl    -8(%ebp), %eax
    addl    $2, %eax
    imull    %eax, %edx
    movl    -8(%ebp), %eax
    incl    %eax
    imull    %eax, %edx
    movl    -8(%ebp), %eax
    incl    %eax
    imull    %edx, %eax
    sarl    $2, %eax
    movl    %eax, c(,%ecx,4)
    incl    -8(%ebp)
.L2:
    cmpl    $31, -8(%ebp)
    jle    .L3
    # for (i = 2; i <= 5000003; i++), i.e. i < MAXN
    movl    $2, -8(%ebp)
    jmp    .L5
.L6:
    # if (!mk[i]) prime[pnum++] = i
    movl    -8(%ebp), %eax
    movzbl    mk(%eax), %eax
    testb    %al, %al
    jne    .L7
    movl    pnum, %eax
    movl    -8(%ebp), %edx
    movl    %edx, prime(,%eax,4)
    incl    %eax
    movl    %eax, pnum
    # if (i <= 2235), i.e. i < MAXR: for (j = i*i; j <= 5000003; j += i) mk[j] = 1
    cmpl    $2235, -8(%ebp)
    jg    .L7
    movl    -8(%ebp), %eax
    imull    -8(%ebp), %eax
    movl    %eax, -4(%ebp)
    jmp    .L10
.L11:
    movl    -4(%ebp), %eax
    movb    $1, mk(%eax)
    movl    -8(%ebp), %eax
    addl    %eax, -4(%ebp)
.L10:
    cmpl    $5000003, -4(%ebp)
    jle    .L11
.L7:
    incl    -8(%ebp)
.L5:
    cmpl    $5000003, -8(%ebp)
    jle    .L6
    leave
    ret
    .size    init, .-init
    .section    .rodata
# Format strings: .LC0 for scanf, .LC1 for printf.
.LC0:
    .string    "%d"
.LC1:
    .string    "%d\n"
    .text
.globl main
    .type    main, @function
# int main(void)
# Stack slots: -20 = ncase, -24 = n, -8 = res, -16 = i, -12 = j,
#              -40 = scratch copy of prime[i] used as the idivl divisor.
main:
    # Align %esp down to 16 bytes and build the frame. %ecx (original %esp + 4)
    # is saved at -4(%ebp) and used by the epilogue to restore %esp.
    leal    4(%esp), %ecx
    andl    $-16, %esp
    pushl    -4(%ecx)
    pushl    %ebp
    movl    %esp, %ebp
    pushl    %ecx
    subl    $52, %esp
    call    init
    # scanf("%d", &ncase)
    leal    -20(%ebp), %eax
    movl    %eax, 4(%esp)
    movl    $.LC0, (%esp)
    call    scanf
    jmp    .L15
.L16:
    # Case body: scanf("%d", &n); res = 1; i = 0
    leal    -24(%ebp), %eax
    movl    %eax, 4(%esp)
    movl    $.LC0, (%esp)
    call    scanf
    movl    $1, -8(%ebp)
    movl    $0, -16(%ebp)
    jmp    .L17
.L18:
    # if (!mk[n]) { if (n > 1) res *= c[1]; break; }   (c+4 is c[1])
    movl    -24(%ebp), %eax
    movzbl    mk(%eax), %eax
    testb    %al, %al
    jne    .L19
    movl    -24(%ebp), %eax
    cmpl    $1, %eax
    jle    .L23
    movl    c+4, %eax
    movl    -8(%ebp), %edx
    imull    %edx, %eax
    movl    %eax, -8(%ebp)
    jmp    .L23
.L19:
    # else: j = 0, then enter the while loop at its test (.L24)
    movl    $0, -12(%ebp)
    jmp    .L24
.L25:
    # while body: the inline asm from prog_asm.c (between #APP and #NO_APP)
    # stores %eax into n. When .L24 jumps here, %eax holds the remainder
    # copied from %edx (0 on this path), not the quotient -- this is the bug.
#APP
    movl    %eax, -24(%ebp)
#NO_APP
    incl    -12(%ebp)
.L24:
    # while test: n % prime[i] == 0 ?  The remainder (%edx) is copied into
    # %eax for the test, overwriting the quotient left there by idivl.
    movl    -24(%ebp), %edx
    movl    -16(%ebp), %eax
    movl    prime(,%eax,4), %eax
    movl    %eax, -40(%ebp)
    movl    %edx, %eax
    cltd
    idivl    -40(%ebp)
    movl    %edx, %eax
    testl    %eax, %eax
    je    .L25
    # if (j > 0) res *= c[j]
    cmpl    $0, -12(%ebp)
    jle    .L27
    movl    -12(%ebp), %eax
    movl    c(,%eax,4), %edx
    movl    -8(%ebp), %eax
    imull    %edx, %eax
    movl    %eax, -8(%ebp)
.L27:
    # i++
    incl    -16(%ebp)
.L17:
    # for-loop test: i < pnum
    movl    pnum, %eax
    cmpl    %eax, -16(%ebp)
    jl    .L18
.L23:
    # printf("%d\n", res)
    movl    -8(%ebp), %eax
    movl    %eax, 4(%esp)
    movl    $.LC1, (%esp)
    call    printf
.L15:
    # while (ncase--): store ncase - 1 and loop unless the new value is -1
    movl    -20(%ebp), %eax
    decl    %eax
    movl    %eax, -20(%ebp)
    movl    -20(%ebp), %eax
    cmpl    $-1, %eax
    jne    .L16
    # return 0; tear down the frame and undo the stack realignment
    movl    $0, %eax
    addl    $52, %esp
    popl    %ecx
    popl    %ebp
    leal    -4(%ecx), %esp
    ret
    .size    main, .-main
# Common (zero-initialized) globals: mk (5000004 bytes), prime (350000 ints),
# pnum, and c (32 ints).
    .comm    mk,5000004,32
    .comm    prime,1400000,32
    .comm    pnum,4,4
    .comm    c,128,32
    .ident    "GCC: (GNU) 4.1.2 20061115 (prerelease) (Debian 4.1.1-21)"
    .section    .note.GNU-stack,"",@progbits

The difference part is the assembly code:

--- prog_purec.s
+++ prog_asm.s
@@ -1,4 +1,4 @@
-.file    "prog_purec.c"
+.file    "prog_asm.c"
 .text
 .globl init
 .type    init, @function
@@ -108,16 +108,9 @@
     movl    $0, -12(%ebp)
     jmp    .L24
 .L25:
-    movl    -24(%ebp), %edx
-    movl    -16(%ebp), %eax
-    movl    prime(,%eax,4), %eax
-    movl    %eax, -40(%ebp)
-    movl    %edx, %eax
-    cltd
-    idivl    -40(%ebp)
-    movl    %eax, -40(%ebp)
-    movl    -40(%ebp), %eax
+#APP
     movl    %eax, -24(%ebp)
+#NO_APP
     incl    -12(%ebp)
 .L24:
     movl    -24(%ebp), %edx
     movl    -16(%ebp), %eax
     movl    prime(,%eax,4), %eax
     movl    %eax, -40(%ebp)
     movl    %edx, %eax
     cltd
     idivl    -40(%ebp)
     movl    %edx, %eax
     testl    %eax, %eax
     je    .L25

At first glance, the optimization seems successful: in the generated assembly, the loop body shrinks from eleven instructions to two, and one integer division is gone. But the results were wrong. How could a performance tweak lead to a wrong result? I was confused, and only after a long debugging session did I find the reason.

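The root cause is the instruction movl %edx, %eax, which the compiler generates after idivl -40(%ebp) to test whether the remainder is 0. This move overwrites %eax, which held the quotient, so the inline asm stores the remainder (always 0 on that path) into n instead of the quotient. The compiler is entitled to do this: a basic asm statement declares no inputs or outputs, so GCC does not know that it reads %eax or writes n, and it makes no attempt to keep any particular value in a register between statements. So, how do we do this optimization correctly?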

After some thought: if nothing is allowed to overwrite the quotient between idivl and the point where it is stored back into n (such as the extra move mentioned above), the quotient is retained correctly. The corrected version is as follows:

// prog_correct.c
#include <stdio.h>

// Define the problem scale
// MAXN: sieve size; mk[] covers 0 .. MAXN-1, so every input n must lie in [0, MAXN).
// MAXP: capacity of prime[]; must be at least the number of primes below MAXN,
//       because init() appends every prime below MAXN without a bounds check.
// MAXR: init() crosses off multiples only for primes below MAXR (about sqrt(MAXN)).
//       Every composite below MAXN has a prime factor of at most 2236, and
//       2236 = 2^2 * 13 * 43 is not prime, so primes below MAXR are enough.
#define MAXN 5000004
#define MAXP 350000
#define MAXR 2236

char mk[MAXN];         // Sieve flags: for 2 <= k < MAXN, mk[k] == 1 iff k is composite
int prime[MAXP], pnum; // The pnum primes below MAXN, in increasing order (filled by init())
int c[32];             // c[k] = 1^3 + 2^3 + ... + (k+1)^3 (filled by init())

void init()
{
    int i, j;
    for (i = 0; i < 32; i++)
        c[i] = (i + 2) * (i + 2) * (i + 1) * (i + 1) >> 2;
    for (i = 2; i < MAXN; i++)
        if (!mk[i])
        {
            prime[pnum++] = i;
            if (i < MAXR)
                for (j = i * i; j < MAXN; j += i)
                    mk[j] = 1;
        }
}

int main()
{
    int i, j, ncase, n, res;
    init();
    scanf("%d", &ncase);
    while (ncase--)
    {
        scanf("%d", &n);
        res = 1;
        for (i = 0; i < pnum; i++)
            if (!mk[n])
            {
                if (n > 1)
                    res *= c[1];
                break;
            }
            else
            {
                j = 0;
                asm(
                    "jmp    JDG\n"
                    "WIL:\n"
                    "movl    %eax, -24(%ebp)\n"
                    "incl    -12(%ebp)\n"
                    "JDG:\n"
                    "movl    -24(%ebp), %edx\n"
                    "movl    -16(%ebp), %eax\n"
                    "movl    prime(,%eax,4), %eax\n"
                    "movl    %eax, -40(%ebp)\n"
                    "movl    %edx, %eax\n"
                    "cltd\n"
                    "idivl    -40(%ebp)\n"
                    "testl    %edx, %edx\n"
                    "je     WIL\n");
                if (j > 0)
                    res *= c[j];
            }
        printf("%d\n", res);
    }
    return 0;
}

Or, in Intel syntax (as an MSVC-style __asm block):

// prog_correct_intel.c
#include <stdio.h>

// Define the problem scale
// MAXN: sieve size; mk[] covers 0 .. MAXN-1, so every input n must lie in [0, MAXN).
// MAXP: capacity of prime[]; must be at least the number of primes below MAXN,
//       because init() appends every prime below MAXN without a bounds check.
// MAXR: init() crosses off multiples only for primes below MAXR (about sqrt(MAXN)).
//       Every composite below MAXN has a prime factor of at most 2236, and
//       2236 = 2^2 * 13 * 43 is not prime, so primes below MAXR are enough.
#define MAXN 5000004
#define MAXP 350000
#define MAXR 2236

char mk[MAXN];         // Sieve flags: for 2 <= k < MAXN, mk[k] == 1 iff k is composite
int prime[MAXP], pnum; // The pnum primes below MAXN, in increasing order (filled by init())
int c[32];             // c[k] = 1^3 + 2^3 + ... + (k+1)^3 (filled by init())

void init()
{
    int i, j;
    for (i = 0; i < 32; i++)
        c[i] = (i + 2) * (i + 2) * (i + 1) * (i + 1) >> 2;
    for (i = 2; i < MAXN; i++)
        if (!mk[i])
        {
            prime[pnum++] = i;
            if (i < MAXR)
                for (j = i * i; j < MAXN; j += i)
                    mk[j] = 1;
        }
}

// Intel-syntax variant of the corrected program, written as an MSVC-style
// __asm { ... } block. Microsoft supports this inline-assembler form only for
// 32-bit x86 targets (not x64 or ARM), and GCC does not accept it at all.
// The asm refers to the locals i, n and j by name, so the compiler resolves
// their locations.
int main()
{
    int i, j, ncase, n, res;
    init();
    scanf("%d", &ncase);
    while (ncase--)
    {
        scanf("%d", &n);
        res = 1;
        for (i = 0; i < pnum; i++)
            if (!mk[n])
            {
                if (n > 1)
                    res *= c[1];
                break;
            }
            else
            {
                j = 0;
                // Equivalent C for the block below:
                //   for (;;) { q = n / prime[i]; r = n % prime[i];
                //              if (r != 0) break; n = q; j++; }
                // One idiv per iteration: cdq sign-extends eax into edx, then
                // idiv leaves the quotient in eax and the remainder in edx.
                // In MSVC inline asm, the bracketed index in prime[ecx*4] is a
                // byte offset, hence the *4 to reach the int element prime[i].
                __asm {
                UL1:
                    mov ecx, i
                    mov eax, n
                    cdq
                    idiv prime[ecx*4]
                    test edx, edx
                    jne UL2
                    mov n, eax
                    inc j
                    jmp UL1
                UL2:
                }
                if (j > 0) res *= c[j];
            }
        printf("%d\n", res);
    }
    return 0;
}

Published Nov 8, 2008

Flying code monkey