#pragma once #include "tristate.hpp" #include "base.hpp" #include "constant.hpp" namespace __lava { struct [[gnu::packed]] zero_storage { using underlying_type = uint8_t; enum value_type : underlying_type { zero = 0, nonzero = 1, unknown = 2, bottom = 3 } value; constexpr zero_storage() : value( unknown ) {} constexpr zero_storage( value_type v ) : value( v ) {} constexpr zero_storage( uint8_t v ) : zero_storage( value_type( v ) ) {} }; struct zero : tagged_storage< zero_storage >, domain_mixin< zero > { using ze = zero_storage; using zv = zero; using zr = const zero &; using ref = domain_ref< zero >; zero( ze::value_type value ) : tagged_storage< ze >( value ) {} zero( ze::underlying_type value ) : zero( ze::value_type( value ) ) {} zero( const zero &o ) : zero( o->value ) {} zero( void *v, __dios::construct_shared_t s ) : tagged_storage< ze >( v, s ) {} __inline ze::value_type value() const { return this->get().value; } __inline bool is_top() const { return value() == ze::unknown; } __inline bool is_bottom() const { return value() == ze::bottom; } __inline bool is_zero() const { return value() == ze::zero; } __inline bool is_nonzero() const { return value() == ze::nonzero; } static zero top() { return ze::unknown; } static zero bottom() { return ze::bottom; } template< typename type > static zv lift( const type &v ) { if constexpr ( std::is_integral_v< type > || std::is_pointer_v< type > ) return v ? ze::nonzero : ze::zero; else fail(); } static constant lower( zr i ) { fail(); } template< typename type > static zero any() { return ze::unknown; } void intersect( ze::value_type b ) { auto &val = get().value; if ( val == ze::zero ) val = b != ze::nonzero ? ze::zero : ze::bottom; else if ( val == ze::nonzero ) val = b != ze::zero ? ze::nonzero : ze::bottom; else // value() == ze::unknown val = b; if ( is_bottom() ) __vm_cancel(); } static ze::value_type invert( ze::value_type a ) { if ( a == ze::nonzero ) return ze::zero; if ( a == ze::zero ) return ze::nonzero; return ze::unknown; } static void assume( zero &z, bool constraint ) { z.intersect( constraint ? ze::nonzero : ze::zero ); if ( z.is_bottom() ) __vm_cancel(); } static tristate to_tristate( zr z ) { return { ( tristate::value_t ) z->value }; } static zv add( zr a, zr b ) { auto r = a->value + b->value; return r < ze::unknown ? r : ze::unknown; } static zv mul( zr a, zr b ) { auto r = a->value * b->value; return r < ze::unknown ? r : ze::unknown; } static zv div( zr a, zr b ) { auto fault = _VM_Fault::_VM_F_Integer; if ( b->value != ze::nonzero ) { // division by 'zero' or 'unknown' __dios_fault( fault, "division by zero" ); return ze::unknown; } return a->value ? ze::unknown : ze::zero; } static zv shift( zr a, zr b ) { if ( a.is_zero() ) // shifting 'zero' always results in 'zero' return ze::zero; return b.is_zero() ? ze::nonzero : ze::unknown; } static zv bwand( zr a, zr b ) { return a->value * b->value ? ze::unknown : ze::zero; } static zv bwor( zr a, zr b ) { if ( a.is_nonzero() || b.is_nonzero() ) return ze::nonzero; if ( a.is_zero() && b.is_zero() ) return ze::zero; return ze::unknown; } static zv bwxor( zr a, zr b ) { auto r = a->value + b->value; return r < ze::unknown ? r : ze::unknown; } template< bool equality > static zv eq( zr a, zr b ) { // two Z if ( a.is_zero() && b.is_zero() ) return equality ? ze::nonzero : ze::zero; // one Z, one N if ( a->value + b->value == ze::nonzero ) return equality ? ze::zero : ze::nonzero; return ze::unknown; } static zv ugt( zr a, zr b ) { if ( a.is_zero() ) return ze::zero; if ( a.is_nonzero() && b.is_zero() ) return ze::nonzero; return ze::unknown; } static zv uge( zr a, zr b ) { if ( b.is_zero() ) return ze::nonzero; if ( b.is_nonzero() && a.is_zero() ) return ze::zero; return ze::unknown; } template< bool strict > static zv scmp( zr a, zr b ) { bool both_zero = a.is_zero() && b.is_zero(); if constexpr ( strict ) { return both_zero ? ze::zero : ze::unknown; } else { return both_zero ? ze::nonzero : ze::unknown; } } static void memop_check( zr p ) { if ( p->value != ze::nonzero ) __dios_fault( _VM_Fault::_VM_F_Memory, "null pointer dereference" ); // TODO out of bounds } // Binary ops static zv op_add ( zr a, zr b ) { return add( a, b ); } static zv op_sub ( zr a, zr b ) { return add( a, b ); } static zv op_mul ( zr a, zr b ) { return mul( a, b ); } static zv op_sdiv( zr a, zr b ) { return div( a, b ); } static zv op_udiv( zr a, zr b ) { return div( a, b ); } static zv op_srem( zr a, zr b ) { return div( a, b ); } static zv op_urem( zr a, zr b ) { return div( a, b ); } // TODO: Binary ops with floating values // Bitwise binary ops static zv op_shl ( zr a, zr b ) { return shift( a, b ); } static zv op_ashr( zr a, zr b ) { return shift( a, b ); } static zv op_lshr( zr a, zr b ) { return shift( a, b ); } static zv op_and ( zr a, zr b ) { return bwand( a, b ); } static zv op_or ( zr a, zr b ) { return bwor ( a, b ); } static zv op_xor ( zr a, zr b ) { return bwxor( a, b ); } // Comparison ops static zv op_eq ( zr a, zr b ) { return eq< true >( a, b ); } static zv op_ne ( zr a, zr b ) { return eq< false >( a, b ); } static zv op_ugt( zr a, zr b ) { return ugt( a, b ); } static zv op_uge( zr a, zr b ) { return uge( a, b ); } static zv op_ult( zr a, zr b ) { return ugt( b, a ); } static zv op_ule( zr a, zr b ) { return uge( b, a ); } static zv op_sgt( zr a, zr b ) { return scmp< true >( a, b ); } static zv op_sge( zr a, zr b ) { return scmp< false >( a, b ); } static zv op_slt( zr a, zr b ) { return scmp< true >( a, b ); } static zv op_sle( zr a, zr b ) { return scmp< false >( a, b ); } static zv op_zfit ( zr z, bw ) { return z.is_zero() ? ze::zero : ze::unknown; } static zv op_trunc( zr z, bw ) { return z.is_zero() ? ze::zero : ze::unknown; } static zv op_sext ( zr z, bw ) { return z.clone(); } static zv op_zext ( zr z, bw ) { return z.clone(); } template< typename scal > static void op_store( zr p, const scal&, bw ) { memop_check( p ); } static zv op_load ( zr p, bw ) { memop_check( p ); return ze::unknown; } static zv op_concat( zr a, zr b ) { if ( a.is_nonzero() || b.is_nonzero() ) return ze::nonzero; if ( a.is_zero() && b.is_zero() ) return ze::zero; return ze::unknown; } static zv op_extract( zr z, bw, bw ) { return z.is_zero() ? ze::zero : ze::unknown; } static void b_add( zr r, ref a, ref b ) { // r = N (a + b = N) if ( r.is_nonzero() ) { if ( a.is_zero() ) b.intersect( ze::nonzero ); if ( b.is_zero() ) a.intersect( ze::nonzero ); } // r = Z (a + b = Z) else if ( r.is_zero() ) { a.intersect( b->value ); b.intersect( a->value ); } } static void b_mul( zr r, ref a, ref b ) { // r = N (a * b = N) if ( r.is_nonzero() ) { a.intersect( ze::nonzero ); b.intersect( ze::nonzero ); } // r = Z (a * b = Z) else if ( r.is_zero() ) { if ( a.is_nonzero() ) b.intersect( ze::zero ); if ( b.is_nonzero() ) a.intersect( ze::zero ); } } static void b_div( zr r, ref a, ref b ) { // r = N (a / b = N) if ( r.is_nonzero() ) { a.intersect( ze::nonzero ); } b.intersect( ze::nonzero ); } static void b_shift( zr r, ref a, ref b ) { // (a << b) = N if ( r.is_nonzero() ) { a.intersect( ze::nonzero ); } // (a << b) = Z else if ( r.is_zero() ) { if ( a.is_nonzero() ) b.intersect( ze::nonzero ); if ( b.is_zero() ) a.intersect( ze::zero ); } } static void b_bwand( zr r, ref a, ref b ) { // (a & b) = N if ( r.is_nonzero() ) { a.intersect( ze::nonzero ); b.intersect( ze::nonzero ); } } static void b_bwor( zr r, ref a, ref b ) { // (a | b) = N if ( r.is_nonzero() ) { if ( a.is_zero() ) b.intersect( ze::nonzero ); if ( b.is_zero() ) a.intersect( ze::nonzero ); // (a | b) = Z } else if ( r.is_zero() ) { a.intersect( ze::zero ); b.intersect( ze::zero ); } } template< bool equality > static void b_eq( zr r, ref a, ref b ) { auto result = equality ? r->value : invert( r->value ); // result = N (a == b) if ( result == ze::nonzero ) { a.intersect( b->value ); b.intersect( a->value ); } // result = Z (a != b) else if ( result == ze::zero ) { if ( a.is_zero() ) b.intersect( ze::nonzero ); if ( b.is_zero() ) a.intersect( ze::nonzero ); } } static void b_ugt( zr r, ref a, ref b ) { // r = N (a > b) if ( r.is_nonzero() ) { a.intersect( ze::nonzero ); } // r = Z (a <= b) else if ( r.is_zero() ) { if ( b.is_zero() ) a.intersect( ze::zero ); if ( a.is_nonzero() ) b.intersect( ze::nonzero ); } } static void b_uge( zr r, ref a, ref b ) { // r = N (a >= b) if ( r.is_nonzero() ) { if ( a.is_zero() ) b.intersect( ze::zero ); if ( b.is_nonzero() ) a.intersect( ze::nonzero ); } // r = Z (a < b) else if ( r.is_zero() ) { b.intersect( ze::nonzero ); } } template< bool strict > static void b_scmp( zr r, ref a, ref b ) { auto result = strict ? r->value : invert( r->value ); // (a > b) or (a < b) if ( result == ze::nonzero ) { if ( a.is_zero() ) b.intersect( ze::nonzero ); if ( b.is_zero() ) a.intersect( ze::nonzero ); } } // Binary ops static void bop_add ( zr r, zr a, zr b ) { b_add( r, a, b ); } static void bop_sub ( zr r, zr a, zr b ) { b_add( r, a, b ); } static void bop_mul ( zr r, zr a, zr b ) { b_mul( r, a, b ); } static void bop_sdiv( zr r, zr a, zr b ) { b_div( r, a, b ); } static void bop_udiv( zr r, zr a, zr b ) { b_div( r, a, b ); } static void bop_srem( zr r, zr a, zr b ) { b_div( r, a, b ); } static void bop_urem( zr r, zr a, zr b ) { b_div( r, a, b ); } // TODO: Binary ops with floating values // Bitwise binary ops static void bop_shl ( zr r, zr a, zr b ) { b_shift( r, a, b ); } static void bop_ashr( zr r, zr a, zr b ) { b_shift( r, a, b ); } static void bop_lshr( zr r, zr a, zr b ) { b_shift( r, a, b ); } static void bop_and ( zr r, zr a, zr b ) { b_bwand( r, a, b ); } static void bop_or ( zr r, zr a, zr b ) { b_bwor ( r, a, b ); } static void bop_xor ( zr r, zr a, zr b ) { b_eq< false >( r, a, b ); } // Comparison ops static void bop_eq ( zr r, zr a, zr b ) { b_eq< true >( r, a, b ); } static void bop_ne ( zr r, zr a, zr b ) { b_eq< false >( r, a, b ); } static void bop_ugt( zr r, zr a, zr b ) { b_ugt( r, a, b ); } static void bop_uge( zr r, zr a, zr b ) { b_uge( r, a, b ); } static void bop_ult( zr r, zr a, zr b ) { b_ugt( r, b, a ); } static void bop_ule( zr r, zr a, zr b ) { b_uge( r, b, a ); } static void bop_sgt( zr r, zr a, zr b ) { b_scmp< true >( r, a, b ); } static void bop_sge( zr r, zr a, zr b ) { b_scmp< false >( r, a, b ); } static void bop_slt( zr r, zr a, zr b ) { b_scmp< true >( r, a, b ); } static void bop_sle( zr r, zr a, zr b ) { b_scmp< false >( r, a, b ); } static void bop_zfit( zr r, ref a ) { a->value = r->value; } }; }