- [Psy-X] optimized version of half floats

This commit is contained in:
Ilya Shurumov 2021-03-02 00:08:11 +06:00
parent 15ce33b102
commit 6101f2373d

View File

@ -1,55 +1,56 @@
#include "COMMON/half_float.h"
#include <TYPES.H>
// see https://gist.github.com/rygorous/2156668
union FP32
{
uint u;
float f;
struct
{
uint Mantissa : 23;
uint Exponent : 8;
uint Sign : 1;
};
};
union FP16
{
unsigned short u;
struct
{
uint Mantissa : 10;
uint Exponent : 5;
uint Sign : 1;
};
};
half::half(const float x)
{
union
{
float floatI;
unsigned int i;
};
// this is a approximate solution
FP32 f = *(FP32*)&x;
FP32 f32infty = { 255 << 23 };
FP32 f16max = { (127 + 16) << 23 };
FP32 magic = { 15 << 23 };
FP32 expinf = { (255 ^ 31) << 23 };
uint sign_mask = 0x80000000u;
FP16 o = { 0 };
floatI = x;
uint sign = f.u & sign_mask;
f.u ^= sign;
// unsigned int i = *((unsigned int *) &x);
int e = ((i >> 23) & 0xFF) - 112;
int m = i & 0x007FFFFF;
sh = (i >> 16) & 0x8000;
if (e <= 0)
{
// Denorm
m = ((m | 0x00800000) >> (1 - e)) + 0x1000;
sh |= (m >> 13);
}
else if (e == 143)
{
sh |= 0x7C00;
if (m != 0)
{
// NAN
m >>= 13;
sh |= m | (m == 0);
}
}
if (!(f.f < f32infty.u)) // Inf or NaN
o.u = f.u ^ expinf.u;
else
{
m += 0x1000;
if (m & 0x00800000)
{
// Mantissa overflow
m = 0;
e++;
}
if (e >= 31)
{
// Exponent overflow
sh |= 0x7C00;
}
else
{
sh |= (e << 10) | (m >> 13);
}
if (f.f > f16max.f) f.f = f16max.f;
f.f *= magic.f;
}
o.u = f.u >> 13; // Take the mantissa bits
o.u |= sign >> 16;
sh = o.u;
}
half::half(const half& other)
@ -59,38 +60,25 @@ half::half(const half& other)
half::operator float() const
{
union
{
unsigned int s;
float result;
};
FP16 h = { sh };
s = (sh & 0x8000) << 16;
unsigned int e = (sh >> 10) & 0x1F;
unsigned int m = sh & 0x03FF;
static const FP32 magic = { 113 << 23 };
static const uint shifted_exp = 0x7c00 << 13; // exponent mask after shift
FP32 o;
if (e == 0)
{
// +/- 0
if (m == 0) return result;
o.u = (h.u & 0x7fff) << 13; // exponent/mantissa bits
uint exp = shifted_exp & o.u; // just the exponent
o.u += (127 - 15) << 23; // exponent adjust
// Denorm
while ((m & 0x0400) == 0)
{
m += m;
e--;
}
e++;
m &= ~0x0400;
}
else if (e == 31)
// handle exponent special cases
if (exp == shifted_exp) // Inf/NaN?
o.u += (128 - 16) << 23; // extra exp adjust
else if (exp == 0) // Zero/Denormal?
{
// INF / NAN
s |= 0x7F800000 | (m << 13);
return result;
o.u += 1 << 23; // extra exp adjust
o.f -= magic.f; // renormalize
}
s |= ((e + 112) << 23) | (m << 13);
return result;
o.u |= (h.u & 0x8000) << 16; // sign bit
return o.f;
}