Branchfree Saturating Arithmetic
Saturated arithmetic is arithmetic where values "saturate" instead of overflowing or underflowing. An overflowed value is converted to the maximum representable by the type, and similarly an underflowed value is converted to the minimum representable. This is useful in areas like computer graphics, where overflow in calculations can result in unsightly glitches in images.
The difficulty with implementing saturated arithmetic is creating all the checks for overflow. The naive method requires several branches in the code. Since jumps tend to be slow to execute, the resulting math executes slowly. For performance, branch-free code is preferred. Another problem is that the overflow and carry flags are not directly accessible from high-level languages like C. This means we may need rather complex roundabout methods to determine whether or not an overflow occurred. Such code is difficult to make fast.
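As an aside not covered in the original text, more recent versions of gcc (and clang) expose the overflow information directly through checked-arithmetic builtins. A minimal sketch of a saturating unsigned add written that way, assuming a compiler that provides __builtin_add_overflow, looks like:

```c
#include <limits.h>

/* Saturating unsigned add via the gcc/clang checked-arithmetic builtin.
 * The builtin returns non-zero when the addition wrapped, and the
 * compiler will typically turn the branch into a conditional move. */
unsigned sat_addu(unsigned x, unsigned y)
{
	unsigned res;

	if (__builtin_add_overflow(x, y, &res))
		return UINT_MAX;

	return res;
}
```

This trades the bit-tricks below for compiler support, but it only helps on compilers that implement the builtin.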
Unsigned Saturating Arithmetic
The simplest case is saturated arithmetic with unsigned integers. There, the minimal value is zero, and the maximal value has all bits set. The simplicity of the bit-representations of these boundaries means we can use efficient bit-tricks in the algorithms.
Before we start, we will define some simple typedefs in order to make the bit-depths of the various types explicit.
#include <limits.h>
typedef unsigned u32b;
typedef unsigned long long u64b;
typedef __uint128_t u128b;
typedef signed s32b;
typedef signed long long s64b;
typedef __int128_t s128b;
Addition is easy. If the result is smaller than either of the initial values (and we only need to check one of them), then we have overflowed. If so, we go to the all-bits-set state. This can be done by logical-oring with -1. Since 1 is the "true" value returned by Boolean operations, all we need to do is negate it to create the bit-mask.
u32b sat_addu32b(u32b x, u32b y)
{
	u32b res = x + y;
	res |= -(res < x);
	return res;
}
Subtraction is slightly harder. Here, we want to set the value to zero if it underflows. We can do that by logical-anding with 0. If it doesn't underflow we need to logical-and with all bits set, i.e. -1. The test for underflow is similar to that for overflow with addition. The subtlety is that we need to invert it because the mask is the other way around. The resulting code looks like:
u32b sat_subu32b(u32b x, u32b y)
{
	u32b res = x - y;
	res &= -(res <= x);
	return res;
}
Unsigned saturating division is the easiest of all. Here, there is no possible way for the result to overflow or underflow, and the code is identical to normal unsigned division. Note that division by zero is undefined, and we keep it undefined for saturating arithmetic as well. (Since unsigned integers don't have a signed zero like IEEE floating point numbers, we don't know the sign of the resulting infinity, and thus which of the bounding values to use.)
u32b sat_divu32b(u32b x, u32b y)
{
	/* No overflow possible */
	return x / y;
}
Multiplication is a bit more difficult. Here, the simplest code is obtained by using a type twice the width of the original unsigned integer. This gives us access to the overflowing bits, if there are any. By using the "convert to Boolean" pseudo-operator !!, we can check to see if overflow has happened. If it has, then we can do the same thing we did with addition, and use a logical-or with a mask with all bits set to create the correct result.
u32b sat_mulu32b(u32b x, u32b y)
{
	u64b res = (u64b) x * (u64b) y;
	u32b hi = res >> 32;
	u32b lo = res;
	return lo | -!!hi;
}
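To see the saturation in action, here is a tiny usage sketch in the spirit of the graphics motivation above. The brighten helper is hypothetical, not from the original text; the saturating add is repeated so the snippet stands alone:

```c
typedef unsigned u32b;

/* Saturating add, repeated from above */
static u32b sat_addu32b(u32b x, u32b y)
{
	u32b res = x + y;
	res |= -(res < x);
	return res;
}

/* Hypothetical helper: brighten every sample in a buffer.  Samples
 * that would wrap past the maximum stick there instead of wrapping
 * around to a dark value. */
void brighten(u32b *buf, int n, u32b amount)
{
	for (int i = 0; i < n; i++)
		buf[i] = sat_addu32b(buf[i], amount);
}
```

With wrapping arithmetic, the near-maximum sample would wrap to a tiny value; with saturation it pins at all-bits-set.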
The above code has been written for 32 bit integers. The 64 bit version requires a 128 bit type in order to access the high bits after a multiplication. Such a type isn't in standard C, but fortunately gcc provides one for us. As can be seen, only the code for multiplication changes, and there only by using the different intermediate type.
u64b sat_addu64b(u64b x, u64b y)
{
	u64b res = x + y;
	res |= -(res < x);
	return res;
}

u64b sat_subu64b(u64b x, u64b y)
{
	u64b res = x - y;
	res &= -(res <= x);
	return res;
}

u64b sat_divu64b(u64b x, u64b y)
{
	/* No overflow possible */
	return x / y;
}

u64b sat_mulu64b(u64b x, u64b y)
{
	u128b res = (u128b) x * (u128b) y;
	u64b hi = res >> 64;
	u64b lo = res;
	return lo | -!!hi;
}
Unfortunately, the compiler isn't quite as smart as we would like. The above can be implemented in a completely branchless manner, but to do that the compiler likes to use the setcc-type instructions. Instead, we can use the carry flag directly via the sbb instruction. If the carry is set, then sbb reg, reg will set that register to all ones. If it is not set, then that register will be all zeros. This is perfect for the masks we use for bounds-setting.
Addition can use the carry flag directly via this sbb trick:
.globl sat_addu32b
.type sat_addu32b,@function
sat_addu32b:
	add %esi, %edi
	sbb %eax, %eax
	or %edi, %eax
	retq
.size sat_addu32b, .-sat_addu32b
Subtraction can also:
.globl sat_subu32b
.type sat_subu32b,@function
sat_subu32b:
	sub %esi, %edi
	cmc
	sbb %eax, %eax
	and %edi, %eax
	retq
.size sat_subu32b, .-sat_subu32b
However, it is probably better to use a conditional move instead. Notice how the first two instructions in the following can execute in parallel due to their lack of dependencies. This should result in faster code.
.globl sat_subu32b
.type sat_subu32b,@function
sat_subu32b:
	xor %eax, %eax
	sub %esi, %edi
	cmovnc %edi, %eax
	retq
.size sat_subu32b, .-sat_subu32b
The multiplication instruction is also extremely helpful, setting the carry flag when the high half of the result is non-zero. This means we don't need to access the high bits at all, and can use the same trick as in the addition code.
.globl sat_mulu32b
.type sat_mulu32b,@function
sat_mulu32b:
	mov %edi, %eax
	mul %esi
	sbb %edi, %edi
	or %edi, %eax
	retq
.size sat_mulu32b, .-sat_mulu32b
Finally, for completeness, the division code is:
.globl sat_divu32b
.type sat_divu32b,@function
sat_divu32b:
	mov %edi, %eax
	xor %edx, %edx
	div %esi
	retq
.size sat_divu32b, .-sat_divu32b
To convert the above to use 64 bit unsigned integers just requires using 64 bit registers. The only subtlety is that since operations on 32 bit registers automatically zero their upper halves, we can use a 32 bit xor operation to clear a full 64 bit register.
.globl sat_addu64b
.type sat_addu64b,@function
sat_addu64b:
	add %rsi, %rdi
	sbb %rax, %rax
	or %rdi, %rax
	retq
.size sat_addu64b, .-sat_addu64b

.globl sat_subu64b
.type sat_subu64b,@function
sat_subu64b:
	xor %eax, %eax
	sub %rsi, %rdi
	cmovnc %rdi, %rax
	retq
.size sat_subu64b, .-sat_subu64b

.globl sat_mulu64b
.type sat_mulu64b,@function
sat_mulu64b:
	mov %rdi, %rax
	mul %rsi
	sbb %rdi, %rdi
	or %rdi, %rax
	retq
.size sat_mulu64b, .-sat_mulu64b

.globl sat_divu64b
.type sat_divu64b,@function
sat_divu64b:
	mov %rdi, %rax
	xor %edx, %edx
	div %rsi
	retq
.size sat_divu64b, .-sat_divu64b
Signed Saturating Arithmetic
Signed saturating arithmetic is much harder to implement than unsigned. Firstly, signed overflow is undefined behaviour in the C standard, so all of the arithmetic needs to be emulated with unsigned types. Secondly, the bounds are not as simple at the bit level as in the unsigned case, so the code needs to use or generate INT_MIN or INT_MAX of the requisite type. Finally, signed overflow is signalled by the overflow flag rather than the carry flag, so the sbb tricks cannot be used.
One way to implement signed saturating addition is to use a larger type. We can then manually check for overflow, and correct the result accordingly. For 32 bit integers, we then have code looking like:
s32b sat_adds32b(s32b x, s32b y)
{
	s64b res = (s64b) x + (s64b) y;
	if (res < INT_MIN) res = INT_MIN;
	if (res > INT_MAX) res = INT_MAX;
	return res;
}
Of course, this branchy code is quite inefficient. We would like to remove the branches and use some sort of bit-based logical operations instead. The first step is to be able to use same-sized integers in the calculation. To do that, we need some way to notice when overflow happens. In the case of addition, overflow can only occur if both arguments have the same sign. If this is the case and the result has the opposite sign to either of the two arguments, then we have had an overflow (or underflow). We can then use the sign bits to determine the correct results. Code that implements the function that way looks like:
s32b sat_adds32b(s32b x, s32b y)
{
	u32b ux = x;
	u32b uy = y;
	u32b res = ux + uy;

	/* Only if same sign, can they overflow */
	if (!((ux ^ uy) >> 31))
	{
		/* Is the result a different sign? */
		if ((res ^ ux) >> 31)
		{
			/* Saturate */
			res = (ux & 0x80000000) ? 0x80000000 : 0x7fffffff;
		}
	}

	return res;
}
The next step is to start removing the branches. On overflow, we can select the saturated result with a conditional move. To calculate that result we can use the trick that
overflowed_res = (ux >> 31) + INT_MAX;
which is obviously branchfree. Next, we can combine the two if statements into a single one that tests the sign bit, which the compiler can implement with the cmovns instruction. The resulting code looks like:
s32b sat_adds32b(s32b x, s32b y)
{
	u32b ux = x;
	u32b uy = y;
	u32b res = ux + uy;

	/* Calculate overflowed result. (Don't change the sign bit of ux) */
	ux = (ux >> 31) + INT_MAX;

	/* Force compiler to use cmovns instruction */
	if ((s32b) ((ux ^ uy) | ~(uy ^ res)) >= 0)
	{
		res = ux;
	}

	return res;
}
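For comparison, and not in the original article, the same saturating signed add can be written with the gcc/clang checked-arithmetic builtin, reusing the (x >> 31) + INT_MAX trick from the text to pick the bound:

```c
#include <limits.h>

typedef unsigned u32b;
typedef signed s32b;

/* Saturating signed add via __builtin_add_overflow (gcc/clang only).
 * On overflow, positive x saturates to INT_MAX and negative x to
 * INT_MIN; the final cast relies on gcc's modular conversion of
 * out-of-range unsigned values, as does the code in the article. */
s32b sat_adds32b_builtin(s32b x, s32b y)
{
	s32b res;

	if (__builtin_add_overflow(x, y, &res))
		return (s32b) (((u32b) x >> 31) + INT_MAX);

	return res;
}
```

This typically compiles to essentially the add/cmovo sequence shown in the assembly section below.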
Subtraction is similar. However, overflow is only possible when the signs of the two arguments differ.
s32b sat_subs32b(s32b x, s32b y)
{
	u32b ux = x;
	u32b uy = y;
	u32b res = ux - uy;

	ux = (ux >> 31) + INT_MAX;

	/* Force compiler to use cmovns instruction */
	if ((s32b)((ux ^ uy) & (ux ^ res)) < 0)
	{
		res = ux;
	}

	return res;
}
Division, unlike in the unsigned case, can overflow in signed arithmetic. The only case where it happens is when you divide INT_MIN by -1. The correct answer would be INT_MAX + 1, however this obviously doesn't fit inside the signed integer. Another problem is that if we try to do the division with these arguments an exception will be raised. Thus we cannot use the same technique as in addition and subtraction where we correct the overflow after it has occurred.
We instead need to alter the arguments in the special case where they would otherwise overflow. This isn't that hard: by adding one to the argument that starts as INT_MIN we produce the correct answer. Thus the algorithm needs to quickly check for the special case. We then add the true/false (1 or 0) Boolean result to the dividend to prevent overflow:
s32b sat_divs32b(s32b x, s32b y)
{
	/* Only one way to overflow, so test for and prevent it.
	   (The casts avoid signed overflow in the test itself.) */
	x += !(((u32b) y + 1) | ((u32b) x + INT_MIN));
	return x / y;
}
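The special case is easy to spot-check. The function is repeated below so the snippet stands alone, with casts added to the test expression so it avoids signed overflow even when y is INT_MAX:

```c
#include <limits.h>

typedef unsigned u32b;
typedef signed s32b;

/* Bump x by one only when x == INT_MIN and y == -1: (u32b)y + 1 is
 * zero only for y == -1, and (u32b)x + INT_MIN is zero only for
 * x == INT_MIN, so the ! yields 1 exactly in the overflowing case. */
s32b sat_divs32b(s32b x, s32b y)
{
	x += !(((u32b) y + 1) | ((u32b) x + INT_MIN));
	return x / y;
}
```

Ordinary operands pass through untouched, while INT_MIN / -1 becomes (INT_MIN + 1) / -1 = INT_MAX instead of trapping.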
Finally, that leaves multiplication. Here, overflow occurs if the high half of the result is not equal to the sign-extended low half. The overflowed value can be computed either from the sign bit of the high half of the result, or from the xor of the sign bits of the two arguments. It appears that the latter is faster, as we can overlap the calculation of the overflow result with the test for overflow.
s32b sat_muls32b(s32b x, s32b y)
{
	s64b res = (s64b) x * (s64b) y;
	u32b res2 = ((u32b) (x ^ y) >> 31) + INT_MAX;

	s32b hi = (res >> 32);
	s32b lo = res;

	if (hi != (lo >> 31)) res = res2;

	return res;
}
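A few spot checks exercise both saturation directions and the INT_MIN * -1 corner. The function is repeated so the snippet compiles alone; note that, as with the original, the final conversion back to s32b relies on gcc's modular conversion behaviour:

```c
#include <limits.h>

typedef unsigned u32b;
typedef signed s32b;
typedef signed long long s64b;

/* Repeated from the text: overflow iff the high half of the 64 bit
 * product differs from the sign-extension of the low half. */
s32b sat_muls32b(s32b x, s32b y)
{
	s64b res = (s64b) x * (s64b) y;
	u32b res2 = ((u32b) (x ^ y) >> 31) + INT_MAX;
	s32b hi = (res >> 32);
	s32b lo = res;

	if (hi != (lo >> 31)) res = res2;

	return res;
}
```

0x10000 * 0x10000 is exactly 2^32, so it overflows upward; flipping one sign overflows downward; and INT_MIN * -1 is the one positive overflow with mixed input signs.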
The 64 bit code is very similar to the 32 bit code. Again, multiplication is the only major change, where 128 bit arithmetic is required.
s64b sat_adds64b(s64b x, s64b y)
{
	u64b ux = x;
	u64b uy = y;
	u64b res = ux + uy;

	ux = (ux >> 63) + LLONG_MAX;

	/* Force compiler to use cmovns instruction */
	if ((s64b) ((ux ^ uy) | ~(uy ^ res)) >= 0)
	{
		res = ux;
	}

	return res;
}

s64b sat_subs64b(s64b x, s64b y)
{
	u64b ux = x;
	u64b uy = y;
	u64b res = ux - uy;

	ux = (ux >> 63) + LLONG_MAX;

	/* Force compiler to use cmovns instruction */
	if ((s64b)((ux ^ uy) & (ux ^ res)) < 0)
	{
		res = ux;
	}

	return res;
}

s64b sat_divs64b(s64b x, s64b y)
{
	/* Only one way to overflow, so test for and prevent it. */
	x += !(((u64b) y + 1) | ((u64b) x + LLONG_MIN));
	return x / y;
}

s64b sat_muls64b(s64b x, s64b y)
{
	s128b res = (s128b) x * (s128b) y;
	u64b res2 = ((u64b) (x ^ y) >> 63) + LLONG_MAX;

	s64b hi = (res >> 64);
	s64b lo = res;

	if (hi != (lo >> 63)) res = res2;

	return res;
}
Unfortunately, even though the above can be compiled into nice branch-free code, gcc doesn't do a very good job of it. Again, we will need to use assembly language to get the best results. Fortunately, in asm we have direct access to the overflow flag. This makes the tests for overflow much simpler for addition and subtraction. By using the cmovo instruction we can correct for overflow when it occurs.
.globl sat_adds32b
.type sat_adds32b,@function
sat_adds32b:
	mov %edi, %eax
	shr $0x1f, %edi
	add $0x7fffffff, %edi
	add %esi, %eax
	cmovo %edi, %eax
	retq
.size sat_adds32b, .-sat_adds32b

.globl sat_subs32b
.type sat_subs32b,@function
sat_subs32b:
	mov %edi, %eax
	shr $0x1f, %edi
	add $0x7fffffff, %edi
	sub %esi, %eax
	cmovo %edi, %eax
	retq
.size sat_subs32b, .-sat_subs32b
Division also has a little room for optimization. There, we can use a trick that the compiler misses to combine the Boolean test with the increment of the dividend. By using the neg instruction we can set the carry flag when the overflow test result is non-zero. Then via the sbb instruction we can subtract -1 from (equivalently, add 1 to) the dividend when required.
.globl sat_divs32b
.type sat_divs32b,@function
sat_divs32b:
	mov %edi, %eax
	lea 0x1(%rsi), %edx
	add $0x80000000, %edi
	or %edx, %edi
	cdq
	neg %edi
	sbb $-1, %eax
	idiv %esi
	retq
.size sat_divs32b, .-sat_divs32b
Finally we have signed saturating multiplication. This is much easier to implement in assembly language. Fortunately for us, the imul instruction will set the carry flag on signed overflow. Thus we can use a similar algorithm to that used for addition and subtraction rather than that used in C.
.globl sat_muls32b
.type sat_muls32b,@function
sat_muls32b:
	mov %edi, %eax
	xor %esi, %edi
	shr $0x1f, %edi
	add $0x7fffffff, %edi
	imul %esi
	cmovc %edi, %eax
	retq
.size sat_muls32b, .-sat_muls32b
The 64 bit versions are very similar to their 32 bit counterparts. The only major difference is of course the use of 64 bit registers. However, note that some instructions change due to slightly different mnemonics. Also notice that one cannot use a 64 bit immediate in logical and arithmetic operations. Thus we have a choice of using a spare register with the movabs instruction, or, as below, using rip-relative addressing.
long_max:
	.quad 0x7fffffffffffffff
long_min:
	.quad 0x8000000000000000

.globl sat_adds64b
.type sat_adds64b,@function
sat_adds64b:
	mov %rdi, %rax
	shr $0x3f, %rdi
	add long_max(%rip), %rdi
	add %rsi, %rax
	cmovo %rdi, %rax
	retq
.size sat_adds64b, .-sat_adds64b

.globl sat_subs64b
.type sat_subs64b,@function
sat_subs64b:
	mov %rdi, %rax
	shr $0x3f, %rdi
	add long_max(%rip), %rdi
	sub %rsi, %rax
	cmovo %rdi, %rax
	retq
.size sat_subs64b, .-sat_subs64b

.globl sat_muls64b
.type sat_muls64b,@function
sat_muls64b:
	mov %rdi, %rax
	xor %rsi, %rdi
	shr $0x3f, %rdi
	add long_max(%rip), %rdi
	imul %rsi
	cmovc %rdi, %rax
	retq
.size sat_muls64b, .-sat_muls64b

.globl sat_divs64b
.type sat_divs64b,@function
sat_divs64b:
	mov %rdi, %rax
	lea 0x1(%rsi), %rdx
	add long_min(%rip), %rdi
	or %rdx, %rdi
	cqo
	neg %rdi
	sbb $-1, %rax
	idiv %rsi
	retq
.size sat_divs64b, .-sat_divs64b
Other Branchy Functions
There are other common cases where branches may appear in basic arithmetic. The most obvious are the min and max functions (or macros) that calculate the minimum or maximum of two values. Fortunately, the compiler is very good at optimizing these. In fact, it will typically use a conditional move instead of a branch, which is optimal.
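For reference, min and max can also be written with the same mask technique used throughout this article. This is a classic bit-trick, not from the original text, and on modern compilers the plain ternary version is usually just as good:

```c
typedef signed s32b;

/* -(x < y) is all ones when x < y and zero otherwise, so the xor
 * chain selects x or y without a branch. */
s32b min32(s32b x, s32b y)
{
	return y ^ ((x ^ y) & -(x < y));
}

s32b max32(s32b x, s32b y)
{
	return x ^ ((x ^ y) & -(x < y));
}
```

When the mask is all ones, y ^ (x ^ y) collapses to x; when it is zero, the xor with zero leaves y (or x, for max) unchanged.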
Another common function is abs, which calculates the absolute value of a signed integer. There is a very well-known trick to implement this efficiently in asm. (In fact, it appears as a question in several tests of assembly language knowledge.) Basically, you can use the sign-extending instructions to obtain a bit-mask that is either all ones or all zeros. Then, via the identity -x = ~x + 1, we can selectively take the negative of the argument. Thus C code that looks like:
s32b abs32(s32b x)
{
	if (x >= 0) return x;
	return -x;
}

s64b abs64(s64b x)
{
	if (x >= 0) return x;
	return -x;
}
Can be optimized into
.globl abs32
.type abs32,@function
abs32:
	mov %edi, %eax
	cdq
	xor %edx, %eax
	sub %edx, %eax
	retq
.size abs32, .-abs32

.globl abs64
.type abs64,@function
abs64:
	mov %rdi, %rax
	cqo
	xor %rdx, %rax
	sub %rdx, %rax
	retq
.size abs64, .-abs64
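The same mask trick can be expressed directly in C, which is a sketch rather than anything from the original text. The arithmetic right shift of a negative value is implementation-defined in C, but gcc defines it to sign-extend, giving exactly the cdq-style mask:

```c
typedef signed s32b;

/* C version of the cdq/xor/sub sequence: the shift produces -1 for
 * negative inputs and 0 otherwise, then (x ^ mask) - mask applies
 * the -x = ~x + 1 identity only when the mask is all ones.
 * Like the asm, abs of INT_MIN stays INT_MIN. */
s32b abs32_branchfree(s32b x)
{
	s32b mask = x >> 31;
	return (x ^ mask) - mask;
}
```

With mask == 0 the expression is a no-op; with mask == -1 it computes ~x + 1, i.e. -x.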
One final function that, written naively, involves branches is the signum function. This returns 1 for positive numbers, -1 for negative numbers, and zero otherwise. One way to implement this in C such that the compiler is guided to create good code is:
s32b sgn32(s32b x)
{
	return (x > 0) - (x < 0);
}

s64b sgn64(s64b x)
{
	return (x > 0) - (x < 0);
}
The above tests can be efficiently written in terms of the sbb and setge instructions. However, we can do a little better. By using the neg instruction combined with an add with carry we have the following:
.globl sgn32
.type sgn32,@function
sgn32:
	mov %edi, %eax
	sar $0x1f, %eax
	neg %edi
	adc %eax, %eax
	retq
.size sgn32, .-sgn32

.globl sgn64
.type sgn64,@function
sgn64:
	mov %rdi, %rax
	sar $0x3f, %rax
	neg %rdi
	adc %rax, %rax
	retq
.size sgn64, .-sgn64
The above is rather subtle. If the argument is negative, then the sar instruction creates a minus one. The neg instruction will set the carry flag because the argument is non-zero. Finally, the add will calculate -1 + -1 + 1 = -1, which is the correct answer. If the argument is zero, then the sar will create a zero. The neg instruction will not set the carry flag. The addition will then calculate 0 + 0 + 0 = 0, which is what we want. In the final case, a positive argument, the sar instruction creates a zero and the neg instruction sets the carry flag. The result is 0 + 0 + 1 = 1, which is also right.
If we ignore the ABI, we can do even better. As with the abs function, we can use the sign-extending instructions. Taking the input parameter to be in the A register, and the output in the D register, we have in just three instructions:
# Parameter in %rax, result in %rdx
sgn64_unusualabi:
	cqo
	neg %rax
	adc %rdx, %rdx
Comments

John Regehr said...
res |= -(res < x);
A version that works for any standard compilation platform and compiles to the same assembly code on widespread architectures can be obtained by changing the line to:
res |= -(u32b)(res < x);
long abs64(long i) {
	enum {bits = sizeof(i) * 8 - 1};
	long sign = i & (1L << bits);
	long mask = sign >> bits;
	return (i ^ mask) - mask;
}
s32b sat_adds32b(s32b x, s32b y)
	((s32b) ((ux ^ uy) | ~(uy ^ res)) >= 0)
s32b sat_subs32b(s32b x, s32b y)
	((s32b)((ux ^ uy) & (ux ^ res)) < 0)

The overflow condition in sat_subs32b() is:

	sign bits
	x  y  res	x  y  res
	+  -  -  	0  1  1
	-  +  +  	1  0  0

so the match condition is (x ^ y) & (x ^ res),
read as: "x must be different from y and res".

The overflow condition in sat_adds32b() is:

	sign bits
	x  y  res	x  y  res
	+  +  -  	0  0  1
	-  -  +  	1  1  0

The match condition can be written as (res ^ x) & (res ^ y);
we can take away the complement operation.
Found this website from the github riscv spike source code.
I'd like to know what this code's license is, because I'm using it (adapted) in an open and commercial project.
Thanks!
sat_adds32b:
	mov %edi, %eax
	cltd
	xor $0x7fffffff, %edx
	add %esi, %eax
	cmovo %edx, %eax
	retq
.size sat_adds32b, .-sat_adds32b
.type sat_adds32b, @function
.globl sat_adds32b

sat_subs32b:
	mov %edi, %eax
	cltd
	xor $0x7fffffff, %edx
	sub %esi, %eax
	cmovo %edx, %eax
	retq
.size sat_subs32b, .-sat_subs32b
.type sat_subs32b, @function
.globl sat_subs32b
sat_adds64b:
	mov %rdi, %rax
	cqo
	not %rdx
	xor %edi, %edi
	bts $63, %rdi
	xor %rdx, %rdi
	add %rsi, %rax
	cmovo %rdi, %rax
	retq
.size sat_adds64b, .-sat_adds64b
.type sat_adds64b, @function
.globl sat_adds64b

sat_subs64b:
	mov %rdi, %rax
	cqo
	not %rdx
	xor %edi, %edi
	bts $0x3f, %rdi
	xor %rdx, %rdi
	sub %rsi, %rax
	cmovo %rdi, %rax
	retq
.size sat_subs64b, .-sat_subs64b
.type sat_subs64b, @function
.globl sat_subs64b