// -*- mode: C++; indent-tabs-mode: nil; c-basic-offset: 4 -*-

/*
 * (c) 2014-2016 Vladimír Štill
 * (c) 2014, 2015 Petr Ročkai
 *
 * Permission to use, copy, modify, and distribute this software for any
 * purpose with or without fee is hereby granted, provided that the above
 * copyright notice and this permission notice appear in all copies.
 *
 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 */

#ifdef BRICKS_HAVE_LLVM

#ifndef __STDC_LIMIT_MACROS
#define __STDC_LIMIT_MACROS
#endif
#ifndef __STDC_CONSTANT_MACROS
#define __STDC_CONSTANT_MACROS
#endif

#ifdef DIVINE_RELAX_WARNINGS
DIVINE_RELAX_WARNINGS
#endif

#include <llvm/Linker/Linker.h>
#include <llvm/Bitcode/ReaderWriter.h>
#include <llvm/IR/GlobalValue.h>
#include <llvm/IR/Function.h>
#include <llvm/IR/Module.h>
#include <llvm/IR/Constants.h>
#include <llvm/IR/LLVMContext.h>
#include <llvm/IR/Verifier.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/DiagnosticPrinter.h>
#include <llvm/IRReader/IRReader.h>
#include <llvm/Support/SourceMgr.h>
#include <llvm/Support/raw_ostream.h>
#include <llvm/Support/raw_os_ostream.h>
#include <llvm/Support/FileSystem.h>
#include <llvm/Object/ArchiveWriter.h>
#include <llvm/Object/IRObjectFile.h>
#include <llvm/Transforms/Utils/Cloning.h>

#ifdef DIVINE_RELAX_WARNINGS
DIVINE_UNRELAX_WARNINGS
#endif

#include <brick-assert>
#include <brick-types>
#include <brick-fs>
#include <brick-mmap>

#include <vector>
#include <string>
#include <initializer_list>
#include <unordered_set>
#include <unordered_map>
#include <iostream>
#include <set>
#include <sstream>
#include <algorithm>
#include <experimental/string_view>

#ifndef BRICKS_LLVM_LINK_STRIP_H
#define BRICKS_LLVM_LINK_STRIP_H

namespace std {

// Make the experimental string_view reachable as std::string_view.
using experimental::string_view;

/*
 * Hash support for llvm::StringRef so that it can be used as a key of the
 * unordered standard containers. The characters referenced by the StringRef
 * are hashed through the string_view hash.
 */
template<>
struct hash< llvm::StringRef > {
    size_t operator()( llvm::StringRef sr ) const {
        const std::string_view view( sr.data(), sr.size() );
        hash< std::string_view > hasher;
        return hasher( view );
    }
};

}

namespace brick {

// forward declaration of friend
namespace t_llvm { struct LinkerTest; };

namespace llvm {

struct Annotation : brick::types::Eq {
    Annotation() = default;
    explicit Annotation( std::string anno ) {
        size_t oldoff = 0, off = 0;
        do {
            off = anno.find( oldoff, '.' );
            _parts.emplace_back( anno.substr( oldoff, off - oldoff) );
            oldoff = off + 1;
        } while ( off != std::string::npos );
    }
    template< typename It >
    Annotation( It begin, It end ) : _parts( begin, end ) { }
    Annotation( std::initializer_list< std::string > parts ) : _parts( parts ) { }

    std::string name() { return _parts.back(); }
    Annotation ns() { return Annotation( _parts.begin(), _parts.end() - 1); }
    std::string toString() {
        std::stringstream ss;
        for ( auto &n : _parts )
            ss << n << ".";
        auto str = ss.str();
        return str.substr( 0, str.size() - 1 );
    }

    bool inNamespace( Annotation ns ) {
        return ns._parts.size() < _parts.size()
            && std::equal( ns._parts.begin(), ns._parts.end(), _parts.begin() );
    }

    Annotation dropNamespace( Annotation ns ) {
        return inNamespace( ns )
             ? Annotation( _parts.begin() + ns.size(), _parts.end() )
             : *this;
    }

    size_t size() const { return _parts.size(); }

    bool operator==( const Annotation &o ) const {
        return o.size() == size() && std::equal( _parts.begin(), _parts.end(), o._parts.begin() );
    }

  private:
    std::vector< std::string > _parts;
};

/*
 * Walk the "llvm.global.annotations" array of the module and call
 * yield( function, annotation ) for every entry whose annotated value is a
 * function. Does nothing if the module has no such global.
 */
template< typename Yield >
void enumerateFunctionAnnos( ::llvm::Module &m, Yield yield ) {
    auto annos = m.getNamedGlobal( "llvm.global.annotations" );
    if ( !annos )
        return;

    auto entries = ::llvm::cast< ::llvm::ConstantArray >( annos->getOperand(0) );
    for ( int idx = 0; idx < int( entries->getNumOperands() ); ++idx ) {
        auto entry = ::llvm::cast< ::llvm::ConstantStruct >( entries->getOperand( idx ) );

        // operand 0 wraps the annotated value; skip everything but functions
        auto fn = ::llvm::dyn_cast< ::llvm::Function >( entry->getOperand(0)->getOperand(0) );
        if ( !fn )
            continue;

        // operand 1 wraps the global variable holding the annotation string
        auto strVar = ::llvm::cast< ::llvm::GlobalVariable >( entry->getOperand(1)->getOperand(0) );
        auto strData = ::llvm::cast< ::llvm::ConstantDataArray >( strVar->getOperand(0) );
        std::string text = strData->getAsCString();
        yield( fn, Annotation( text ) );
    }
}

/*
 * Like enumerateFunctionAnnos, but reports only annotations which lie in the
 * namespace ns; the namespace prefix is removed from the reported annotation.
 */
template< typename Yield >
void enumerateFunctionAnnosInNs( Annotation ns, ::llvm::Module &m, Yield yield ) {
    auto filter = [&]( ::llvm::Function *fn, Annotation anno ) {
        if ( !anno.inNamespace( ns ) )
            return;
        yield( fn, anno.dropNamespace( ns ) );
    };
    enumerateFunctionAnnos( m, filter );
}

/*
 * Call yield( function ) for every function of the module which carries
 * exactly the annotation anno.
 */
template< typename Yield >
void enumerateFunctionsForAnno( Annotation anno, ::llvm::Module &m, Yield yield ) {
    auto matcher = [&]( ::llvm::Function *fn, Annotation candidate ) {
        if ( !( candidate == anno ) )
            return;
        yield( fn );
    };
    enumerateFunctionAnnos( m, matcher );
}

template< typename Yield >
void enumerateFunctionsForAnno( std::string anno, ::llvm::Module &m, Yield yield ) {
    enumerateFunctionsForAnno( Annotation( anno ), m, yield );
}

/*
 * Run the LLVM verifier on the module; throws std::runtime_error carrying
 * the verifier output if the module is reported as broken.
 */
inline void verifyModule( ::llvm::Module *module ) {
    std::string report;
    ::llvm::raw_string_ostream sink( report );
    const bool broken = ::llvm::verifyModule( *module, &sink );
    if ( !broken )
        return;
    throw std::runtime_error( "Invalid bitcode: " + sink.str() );
}

// Convert an error code into a std::runtime_error exception; always throws.
static void _throwLLVMError( std::error_code ec ) {
    std::string what( "LLVM Error: " );
    what += ec.message();
    throw std::runtime_error( what );
}

/*
 * Verify the module and write it as bitcode into the file out; the path
 * leading to the file is prepared by brick::fs::mkFilePath.
 *
 * Throws std::runtime_error if the module does not pass verification or if
 * the output file cannot be opened.
 */
inline void writeModule( ::llvm::Module *m, std::string out ) {
    // Pass the pointer and qualify the call: with a Module reference as the
    // argument, ADL selects ::llvm::verifyModule, whose boolean result would
    // be silently dropped and invalid bitcode would get written.
    brick::llvm::verifyModule( m );
    std::error_code serr;
    brick::fs::mkFilePath( out );
    ::llvm::raw_fd_ostream fs( out.c_str(), serr, ::llvm::sys::fs::F_None );
    if ( serr )
        _throwLLVMError( serr );
    WriteBitcodeToFile( m, fs );
}


template< typename It >
inline void writeArchive( It begin, It end, std::string out )
{
    namespace object = ::llvm::object;

    fs::TempDir dir( ".brick-llvm-archive-XXXXXX", fs::AutoDelete::Yes, fs::UseSystemTemp::Yes );
    std::vector< std::string > names;
    std::set< std::string > shortnames;
    std::vector< ::llvm::NewArchiveIterator > its;

    for ( auto &mod : ::llvm::iterator_range< It >( begin, end ) )
    {
        names.emplace_back( fs::joinPath( dir, fs::basename( mod->getModuleIdentifier() ) ) );
        auto i = shortnames.emplace( fs::basename( names.back() ) );
        if ( !i.second )
            throw std::runtime_error( "doubly defined member of archive: " + names.back()
                                       + " (" + mod->getModuleIdentifier() + ")" );
        writeModule( &*mod, names.back() );
        its.emplace_back( names.back(), *i.first );
    }

    fs::mkFilePath( out );
    auto res = ::llvm::writeArchive( out, its, true, object::Archive::K_GNU, true ).second;
    if ( res )
        _throwLLVMError( res );
}

struct ArchiveReader
{
    using Archive = ::llvm::object::Archive;
    using Ctx = ::llvm::LLVMContext;

    explicit ArchiveReader( std::shared_ptr< Ctx > ctx ) : _ctx( ctx ) {
        if ( !_ctx )
            _ctx.reset( new ::llvm::LLVMContext() );
    }

    ArchiveReader( std::string filename, std::shared_ptr< Ctx > ctx = nullptr ) :
        ArchiveReader( ctx )
    {
        _data = mmap::MMap( filename, mmap::ProtectMode::Read | mmap::ProtectMode::Private );

        std::error_code ec;
        _archive = std::make_unique< Archive >( ::llvm::MemoryBufferRef(
                        ::llvm::StringRef( _data.left().data(), _data.left().size() ), filename ), ec );
        if ( ec )
            _throwLLVMError( ec );
    }

    ArchiveReader( std::string filename, std::string_view contents, std::shared_ptr< Ctx > ctx = nullptr ) :
        ArchiveReader( ctx )
    {
        std::error_code ec;
        _archive = std::make_unique< Archive >( ::llvm::MemoryBufferRef(
                        ::llvm::StringRef( contents.data(), contents.size() ), filename ), ec );
        if ( ec )
            _throwLLVMError( ec );
    }

    ArchiveReader( std::string filename, std::string &&contents, std::shared_ptr< Ctx > ctx = nullptr ) :
        ArchiveReader( ctx )
    {
        _data = std::move( contents );
        std::error_code ec;
        _archive = std::make_unique< Archive >( ::llvm::MemoryBufferRef( _data.right(), filename ), ec );
        if ( ec )
            _throwLLVMError( ec );
    }

    struct BitcodeIterator : types::Eq
    {
        using Underlying = ::llvm::object::Archive::child_iterator;

        BitcodeIterator( Underlying beg, Underlying end, std::shared_ptr< Ctx > ctx, bool bump = true ) :
            _it( beg ), _end( end ), _ctx( ctx )
        {
            if ( bump )
                _bump();
        }

        BitcodeIterator( BitcodeIterator && ) = default;
        BitcodeIterator( const BitcodeIterator &o ) :
            _it( o._it ), _end( o._end ), _ctx( o._ctx ), _iro( nullptr )
        { }

        ::llvm::Module &operator*() {
            _load( true );
            return _iro->getModule();
        }

        std::unique_ptr< ::llvm::Module > take() {
            _load( true );
            return _iro->takeModule();
        }

        bool operator==( const BitcodeIterator &o ) const { return _it == o._it;  }

        BitcodeIterator &operator++() {
            ++_it;
            _bump();
            return *this;
        }

        BitcodeIterator operator++( int ) {
            BitcodeIterator copy( _it, _end, _ctx, std::move( _iro ) );
            ++(*this);
            return copy;
        }

      private:
        using IRObjectFile = ::llvm::object::IRObjectFile;

        bool _load( bool errorIsFatal  ) {
            if ( !_iro ) {
                auto b = _it->getAsBinary( _ctx.get() );
                if ( !b && errorIsFatal )
                    _throwLLVMError( b.getError() );
                if ( !b )
                    return false;
                _iro.reset( ::llvm::cast< IRObjectFile >( b.get().release() ) );
            }
            return true;
        }

        void _bump() {
            _iro = nullptr;
            while ( _it != _end ) {
                if ( _load( false ) )
                    break;
                ++_it;
            }
        }

        BitcodeIterator( Underlying beg, Underlying end, std::shared_ptr< Ctx > ctx,
                         std::unique_ptr< IRObjectFile > iro ) :
            _it( beg ), _end( end ), _ctx( ctx ), _iro( std::move( iro ) )
        { }

        Underlying _it;
        Underlying _end;
        std::shared_ptr< Ctx > _ctx;
        std::unique_ptr< IRObjectFile > _iro;
    };

    ::llvm::iterator_range< BitcodeIterator > modules() {
        return { modules_begin(), modules_end() };
    }

    BitcodeIterator modules_begin() { return BitcodeIterator( _archive->child_begin(), _archive->child_end(), _ctx ); }
    BitcodeIterator modules_end() { return BitcodeIterator( _archive->child_end(), _archive->child_end(), _ctx ); }

    BitcodeIterator iterator( ::llvm::object::Archive::child_iterator raw ) {
        return BitcodeIterator( raw, _archive->child_end(), _ctx, false );
    }

    const Archive &archive() const { return *_archive; }

  private:
    types::Either< mmap::MMap, std::string > _data;
    std::unique_ptr< Archive > _archive;
    std::shared_ptr< Ctx > _ctx;
};

struct Linker {
    using Module = ::llvm::Module;
    using StringRef = ::llvm::StringRef;

    void link( Module &src ) {
        return link( std::unique_ptr< Module >( ::llvm::CloneModule( &src ) ) );
    }

    void link( std::unique_ptr< Module > src ) {
        ASSERT( src != nullptr );
        src->materializeAllPermanently();
        if ( !_link ) {
            _link.reset( new ::llvm::Linker( src.release(),
                            [this]( const ::llvm::DiagnosticInfo &i ) { this->_error( i ); } ) );
        } else {
            auto r = _link->linkInModule( src.get() );
            if ( r ) {
                std::cerr << "ERROR: while linking '"
                          << src->getModuleIdentifier() << "'" << std::endl;
                UNREACHABLE( "Linker error" );
            }
        }
        _updateUndefined();
    }

    using ModuleVector = std::vector< std::unique_ptr< Module > >;

    template< typename K, typename V >
    struct ModuleMap : std::unordered_map< K, V > {

        std::unique_ptr< Module > findFirst( const std::vector< StringRef > &undef )
        {
            for ( auto u : undef ) {
                auto it = this->find( u );
                if ( it != this->end() )
                    return _takeModule( it->second );
            }
            return nullptr;
        }
    };

    template< typename It >
    struct ModuleFinder : ModuleMap< std::string, It >
    {
        ModuleFinder( It begin, It end )
        {
            for ( auto it = begin; it != end; ++it )
                _symbols( _getModule( it ), true,
                      [&]( StringRef name, auto ) { this->emplace( name, it ); } );
        }

    };

    struct SymtabFinder : ModuleMap< StringRef, ArchiveReader::BitcodeIterator >
    {
        SymtabFinder( ArchiveReader &r ) {
            for ( auto s : r.archive().symbols() ) {
                auto m = s.getMember();
                if ( !m )
                    _throwLLVMError( m.getError() );
                this->emplace( s.getName(), r.iterator( m.get() ) );
            }
        }
    };

    template< typename Finder >
    void linkArchive( Finder finder ) {
        while ( true ) {
            auto toLink = finder.findFirst( _undefined );
            if ( !toLink )
                return; // no more symbols needed from this archive
            link( std::move( toLink ) );
        }
    }

    template< typename It >
    void linkArchive( It begin, It end ) {
        linkArchive( ModuleFinder< It >( begin, end ) );
    }

    void linkArchive( std::initializer_list< ::llvm::Module * > mods ) {
        linkArchive( mods.begin(), mods.end() );
    }

    void linkArchive( ArchiveReader &r ) {
        if ( r.archive().hasSymbolTable() )
            linkArchive( SymtabFinder( r ) );
        else
            linkArchive( r.modules_begin(), r.modules_end() );
    }

    Module *get() { return _link->getModule(); }
    const Module *get() const { return _link->getModule(); }

    std::unique_ptr< Module > take() {
        std::unique_ptr< Module > m( get() );
        _link.release();
        return m;
    }

    bool hasModule() { return _link != nullptr; }

  private:
    void _error( const ::llvm::DiagnosticInfo &i ) {
        ::llvm::raw_os_ostream err( std::cerr );
        err << "LLVM Linker error: ";
        ::llvm::DiagnosticPrinterRawOStream printer( err );
        i.print( printer );
        err << "\n";
        err.flush();
    }

    template< typename It >
    static std::unique_ptr< Module > _takeModule( It it ) {
        return std::unique_ptr< Module >( ::llvm::CloneModule( *it ) );
    }

    static std::unique_ptr< Module > _takeModule( ArchiveReader::BitcodeIterator &bcit ) {
        return bcit.take();
    }

    template< typename It >
    static Module &_getModule( It it ) { return **it; }

    static Module &_getModule( ArchiveReader::BitcodeIterator &bcit ) { return *bcit; }

    template< typename Yield >
    static void _symbols( Module &m, bool defined, Yield yield )
    {
        for ( auto &fn : m )
            if ( fn.isDeclaration() == !defined && fn.hasName() )
                yield( fn.getName(), &fn );
        for ( auto &glo : m.getGlobalList() )
            if ( glo.isDeclaration() == !defined && glo.hasName() )
                yield( glo.getName(), &glo );
        for ( auto &ali : m.getAliasList() )
            if ( ali.isDeclaration() == !defined && ali.hasName() )
                yield( ali.getName(), &ali );
    }

    void _updateUndefined() {
        _undefined.clear();
        _symbols( *get(), false, [this]( StringRef name, auto ) { _undefined.emplace_back( name ); } );
    }

    friend brick::t_llvm::LinkerTest;

    std::unique_ptr< ::llvm::Linker > _link;
    std::vector< StringRef > _undefined;
};
}

namespace t_llvm {

using namespace brick::llvm;
using ::llvm::cast;
using ::llvm::dyn_cast;
using ::llvm::isa;
using ::llvm::Module;
using ::llvm::StringRef;
using ::llvm::Type;
using ::llvm::FunctionType;
using ::llvm::Function;
using ::llvm::Module;

// Unit tests of Linker, writeArchive and ArchiveReader. The struct is a
// friend of Linker so that the tests can inspect Linker::_undefined.
struct LinkerTest {

    // context shared by all modules created in the tests
    ::llvm::LLVMContext ctx;

    // Add a declaration of a function "i32 name()" to the module m.
    Function *declareFun( Module &m, StringRef name ) {
        auto *fun = cast< Function >( m.getOrInsertFunction( name,
                        FunctionType::get( Type::getInt32Ty( ctx ), false ) ) );
        ASSERT( fun->isDeclaration() );
        return fun;
    }

    // Add a function "i32 name()" whose body just returns 0 to the module m.
    Function *defineFun( Module &m, StringRef name ) {
        auto *fun = declareFun( m, name );
        auto bb = ::llvm::BasicBlock::Create( ctx, "entry", fun );
        ::llvm::IRBuilder<> irb( bb );
        irb.CreateRet( irb.getInt32( 0 ) );
        ASSERT( !fun->isDeclaration() );
        return fun;
    }

    // Number of occurrences of ref in the container x.
    template< typename X >
    static int count( X &x, StringRef ref ) {
        return std::count( x.begin(), x.end(), ref );
    }

    // A declaration in the only linked module is reported as undefined.
    TEST(undefined_single) {
        auto m = std::make_unique< ::llvm::Module >( "test.ll", ctx );
        declareFun( *m, "fun" );

        Linker link;
        link.link( std::move( m ) );
        ASSERT_EQ( link._undefined.size(), 1 );
        ASSERT_EQ( count( link._undefined, "fun" ), 1 );
    }

    // Linking a module which defines "fun" but declares "foo" replaces the
    // undefined symbol "fun" with "foo".
    TEST(undefined_two) {
        auto m = std::make_unique< ::llvm::Module >( "test.ll", ctx );
        declareFun( *m, "fun" );

        auto m2 = std::make_unique< ::llvm::Module >( "test2.ll", ctx );
        defineFun( *m2, "fun" );
        declareFun( *m2, "foo" );

        Linker link;
        link.link( std::move( m ) );
        ASSERT_EQ( link._undefined.size(), 1 );
        ASSERT_EQ( count( link._undefined, "fun" ), 1 );
        link.link( std::move( m2 ) );
        ASSERT_EQ( link._undefined.size(), 1 );
        ASSERT_EQ( count( link._undefined, "foo" ), 1 );
    }

    // Common scenario of the archive tests: "main" defines main and needs
    // foo, module "a" defines bar and baz, module "b" defines foo and foobar.
    // doit( linker, a, b, composite ) performs the linking, check( linker )
    // inspects the result; nothing may remain undefined after doit.
    template< typename Doit, typename Check >
    void archive_simple( Doit doit, Check check ) {
        auto ma = std::make_unique< Module >( "a", ctx );
        auto mb = std::make_unique< Module >( "b", ctx );
        auto mmain = std::make_unique< Module >( "main", ctx );

        defineFun( *mmain, "main" );
        declareFun( *mmain, "foo" );

        defineFun( *ma, "bar" );
        defineFun( *ma, "baz" );

        defineFun( *mb, "foo" );
        defineFun( *mb, "foobar" );

        {
            Linker link;
            link.link( std::move( mmain ) );
            doit( link, ma.get(), mb.get(), link.get() );
            ASSERT_EQ( link._undefined.size(), 0 );
            check( link );
        }
    }

    // Expect main, foo and foobar to be defined in the result and bar and
    // baz to be absent from it.
    static void checkALinked( Linker &link ) {
        ASSERT( link.get()->getFunction( "main" ) );
        ASSERT( !link.get()->getFunction( "main" )->isDeclaration() );
        ASSERT( link.get()->getFunction( "foo" ) );
        ASSERT( !link.get()->getFunction( "foo" )->isDeclaration() );
        ASSERT( link.get()->getFunction( "foobar" ) );
        ASSERT( !link.get()->getFunction( "foobar" )->isDeclaration() );
        ASSERT( !link.get()->getFunction( "bar" ) );
        ASSERT( !link.get()->getFunction( "baz" ) );
    }

    // Only the module defining foo is pulled in, with the archive given in
    // the order a, b...
    TEST(archive_simple_a_in_ab) {
        archive_simple( []( auto &link, auto ma, auto mb, auto ) { link.linkArchive( { ma, mb } ); },
                        checkALinked );
    }

    // ... as well as in the order b, a.
    TEST(archive_simple_a_in_ba) {
        archive_simple( []( auto &link, auto ma, auto mb, auto ) { link.linkArchive( { mb, ma } ); },
                        checkALinked );
    }

    // A declaration of baz is added to the composite before the archive is
    // linked; all five functions are expected to be defined in the result.
    TEST(archive_simple_both) {
        archive_simple( [this]( auto &link, auto ma, auto mb, auto mmain ) {
                declareFun( *mmain, "baz" );
                link.linkArchive( { ma, mb } );
            }, []( auto &link ) {
                ASSERT( link.get()->getFunction( "main" ) );
                ASSERT( !link.get()->getFunction( "main" )->isDeclaration() );
                ASSERT( link.get()->getFunction( "foo" ) );
                ASSERT( !link.get()->getFunction( "foo" )->isDeclaration() );
                ASSERT( link.get()->getFunction( "foobar" ) );
                ASSERT( !link.get()->getFunction( "foobar" )->isDeclaration() );
                ASSERT( link.get()->getFunction( "bar" ) );
                ASSERT( !link.get()->getFunction( "bar" )->isDeclaration() );
                ASSERT( link.get()->getFunction( "baz" ) );
                ASSERT( !link.get()->getFunction( "baz" )->isDeclaration() );
            } );
    }

    // Write two modules into an archive in a temporary directory and read
    // them back; the modules are expected to come out in the written order.
    TEST(ar_archive) {

        fs::TempDir dir( ".brick-llvm-test-XXXXXX", fs::AutoDelete::Yes, fs::UseSystemTemp::Yes );
        auto ma = std::make_unique< Module >( "a", ctx );
        auto mb = std::make_unique< Module >( "b", ctx );
        declareFun( *ma, "foo" );
        declareFun( *mb, "bar" );

        auto name = fs::joinPath( dir, "test.a" );
        auto il = { std::move( ma ), std::move( mb ) };
        writeArchive( il.begin(), il.end(), name );

        ArchiveReader reader( name );
        int count = 0;
        for ( auto &m : reader.modules() ) {
            ++count;
            // the first module has to contain foo, the second one bar
            ASSERT( count != 1 || m.getFunction( "foo" ) );
            ASSERT( count != 2 || m.getFunction( "bar" ) );
        }
        ASSERT_EQ( count, 2 );
    }
};

}

}

#endif // BRICKS_LLVM_LINK_STRIP_H

#endif

// vim: syntax=cpp tabstop=4 shiftwidth=4 expandtab ft=cpp
