Skip to content
8 changes: 6 additions & 2 deletions src/binaryen-c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5888,9 +5888,13 @@ void BinaryenSetTrapsNeverHappen(bool on) {
globalPassOptions.trapsNeverHappen = on;
}

bool BinaryenGetClosedWorld(void) { return globalPassOptions.closedWorld; }
bool BinaryenGetClosedWorld(void) {
return globalPassOptions.worldMode == WorldMode::Closed;
}

void BinaryenSetClosedWorld(bool on) { globalPassOptions.closedWorld = on; }
void BinaryenSetClosedWorld(bool on) {
globalPassOptions.worldMode = on ? WorldMode::Closed : WorldMode::Open;
}

bool BinaryenGetLowMemoryUnused(void) {
return globalPassOptions.lowMemoryUnused;
Expand Down
133 changes: 95 additions & 38 deletions src/ir/module-utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -481,12 +481,16 @@ struct CodeScanner : PostWalker<CodeScanner> {
};

void classifyTypeVisibility(Module& wasm,
InsertOrderedMap<HeapType, HeapTypeInfo>& types);
InsertOrderedMap<HeapType, HeapTypeInfo>& types,
WorldMode worldMode);

} // anonymous namespace

InsertOrderedMap<HeapType, HeapTypeInfo> collectHeapTypeInfo(
Module& wasm, TypeInclusion inclusion, VisibilityHandling visibility) {
InsertOrderedMap<HeapType, HeapTypeInfo>
collectHeapTypeInfo(Module& wasm,
WorldMode worldMode,
TypeInclusion inclusion,
VisibilityHandling visibility) {
// Collect module-level info.
TypeInfos info;
CodeScanner(wasm, info).walkModuleCode(&wasm);
Expand Down Expand Up @@ -593,7 +597,7 @@ InsertOrderedMap<HeapType, HeapTypeInfo> collectHeapTypeInfo(
}

if (visibility == VisibilityHandling::FindVisibility) {
classifyTypeVisibility(wasm, info.info);
classifyTypeVisibility(wasm, info.info, worldMode);
}

return std::move(info.info);
Expand All @@ -602,8 +606,9 @@ InsertOrderedMap<HeapType, HeapTypeInfo> collectHeapTypeInfo(
namespace {

void classifyTypeVisibility(Module& wasm,
InsertOrderedMap<HeapType, HeapTypeInfo>& types) {
for (auto type : getPublicHeapTypes(wasm)) {
InsertOrderedMap<HeapType, HeapTypeInfo>& types,
WorldMode worldMode) {
for (auto type : getPublicHeapTypes(wasm, worldMode)) {
if (auto it = types.find(type); it != types.end()) {
it->second.visibility = Visibility::Public;
}
Expand All @@ -615,6 +620,64 @@ void classifyTypeVisibility(Module& wasm,
}
}

// Collects all heap types transitively reachable from a root set of types.
// Options are provided to customize the traversal:
// - `includeSupertypes`: if true, declared supertypes are also traversed.
// - `includeRecGroups`: if true, all types in the same recursion group
// are also traversed.
std::vector<HeapType>
getTransitivelyReachable(const std::vector<HeapType>& roots,
bool includeSupertypes,
bool includeRecGroups) {
std::vector<HeapType> result;
std::vector<HeapType> worklist;
std::unordered_set<HeapType> seen;
std::unordered_set<RecGroup> seenRecGroups;

auto note = [&](HeapType type) {
if (type.isBasic()) {
if (seen.insert(type).second) {
result.push_back(type);
}
return;
}

if (includeRecGroups) {
auto group = type.getRecGroup();
if (seenRecGroups.insert(group).second) {
for (auto member : group) {
result.push_back(member);
worklist.push_back(member);
}
}
} else {
if (seen.insert(type).second) {
result.push_back(type);
worklist.push_back(type);
}
}
};

for (auto type : roots) {
note(type);
}

while (!worklist.empty()) {
auto curr = worklist.back();
worklist.pop_back();
std::optional<HeapType> super =
includeSupertypes ? std::nullopt : curr.getDeclaredSuperType();
for (auto t : curr.getReferencedHeapTypes()) {
if (super && t == *super) {
continue;
}
note(t);
}
}

return result;
}

void setIndices(IndexedHeapTypes& indexedTypes) {
for (Index i = 0; i < indexedTypes.types.size(); i++) {
indexedTypes.indices[indexedTypes.types[i]] = i;
Expand All @@ -624,7 +687,7 @@ void setIndices(IndexedHeapTypes& indexedTypes) {
} // anonymous namespace

std::vector<HeapType> collectHeapTypes(Module& wasm) {
auto info = collectHeapTypeInfo(wasm);
auto info = collectHeapTypeInfo(wasm, WorldMode::Open);
std::vector<HeapType> types;
types.reserve(info.size());
for (auto& [type, _] : info) {
Expand All @@ -633,27 +696,16 @@ std::vector<HeapType> collectHeapTypes(Module& wasm) {
return types;
}

std::vector<HeapType> getPublicHeapTypes(Module& wasm) {
// Look at the types of imports as exports to get an initial set of public
// types, then traverse the types used by public types and collect the
// transitively reachable public types as well.
std::vector<HeapType> workList;
std::unordered_set<RecGroup> publicGroups;

// The collected types.
std::vector<HeapType> getExposedPublicHeapTypes(Module& wasm) {
// Look at the types of imports and exports to get an initial set of public
// types.
std::vector<HeapType> publicTypes;
std::unordered_set<HeapType> seenTypes;

auto notePublic = [&](HeapType type) {
if (type.isBasic()) {
return;
if (seenTypes.insert(type).second) {
publicTypes.push_back(type);
}
auto group = type.getRecGroup();
if (!publicGroups.insert(group).second) {
// The groups in this type have already been marked public.
return;
}
publicTypes.insert(publicTypes.end(), group.begin(), group.end());
workList.insert(workList.end(), group.begin(), group.end());
};

ModuleUtils::iterImportedTags(wasm, [&](Tag* tag) { notePublic(tag->type); });
Expand Down Expand Up @@ -710,24 +762,28 @@ std::vector<HeapType> getPublicHeapTypes(Module& wasm) {
notePublic(type);
}

// Find all the other public types reachable from directly publicized types.
while (!workList.empty()) {
auto curr = workList.back();
workList.pop_back();
for (auto t : curr.getReferencedHeapTypes()) {
notePublic(t);
return publicTypes;
}

std::vector<HeapType> getPublicHeapTypes(Module& wasm, WorldMode worldMode) {
auto directlyExposed = getExposedPublicHeapTypes(wasm);
auto transitivelyExposed = getTransitivelyReachable(
directlyExposed, /*includeSupertypes=*/true, /*includeRecGroups=*/true);
std::vector<HeapType> publicTypes;
publicTypes.reserve(transitivelyExposed.size());
for (auto type : transitivelyExposed) {
if (!type.isBasic()) {
publicTypes.push_back(type);
}
}

// TODO: In an open world, we need to consider subtypes of public types public
// as well, or potentially even consider all types to be public unless
// otherwise annotated.
return publicTypes;
}

std::vector<HeapType> getPrivateHeapTypes(Module& wasm) {
auto info = collectHeapTypeInfo(
wasm, TypeInclusion::UsedIRTypes, VisibilityHandling::FindVisibility);
std::vector<HeapType> getPrivateHeapTypes(Module& wasm, WorldMode worldMode) {
auto info = collectHeapTypeInfo(wasm,
worldMode,
TypeInclusion::UsedIRTypes,
VisibilityHandling::FindVisibility);
std::vector<HeapType> types;
types.reserve(info.size());
for (auto& [type, typeInfo] : info) {
Expand All @@ -739,7 +795,8 @@ std::vector<HeapType> getPrivateHeapTypes(Module& wasm) {
}

IndexedHeapTypes getOptimizedIndexedHeapTypes(Module& wasm) {
auto counts = collectHeapTypeInfo(wasm, TypeInclusion::BinaryTypes);
auto counts =
collectHeapTypeInfo(wasm, WorldMode::Open, TypeInclusion::BinaryTypes);

// Collect the rec groups.
std::unordered_map<RecGroup, size_t> groupIndices;
Expand Down
18 changes: 12 additions & 6 deletions src/ir/module-utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -472,20 +472,26 @@ struct HeapTypeInfo {

InsertOrderedMap<HeapType, HeapTypeInfo> collectHeapTypeInfo(
Module& wasm,
WorldMode worldMode,
TypeInclusion inclusion = TypeInclusion::AllTypes,
VisibilityHandling visibility = VisibilityHandling::NoVisibility);

// Helper function for collecting all the non-basic heap types used in the
// module, i.e. the types that would appear in the type section.
std::vector<HeapType> collectHeapTypes(Module& wasm);

// Collect all the heap types visible on the module boundary that cannot be
// changed. TODO: For open world use cases, this needs to include all subtypes
// of public types as well.
std::vector<HeapType> getPublicHeapTypes(Module& wasm);
// Get the types directly made public by imported or exported module items. For
// example, the types of imported or exported globals or functions, but not
// other types reachable from those types. Includes abstract heap types.
std::vector<HeapType> getExposedPublicHeapTypes(Module& wasm);
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs a comment, I think - how does it differ from getPublicHeapTypes?


// getHeapTypes - getPublicHeapTypes
std::vector<HeapType> getPrivateHeapTypes(Module& wasm);
// Collect all the defined heap types visible on the module boundary that cannot
// be changed, e.g. the defined types from getExposedPublicHeapTypes and those
// they reach.
std::vector<HeapType> getPublicHeapTypes(Module& wasm, WorldMode worldMode);

// All the defined heap types that are not public.
std::vector<HeapType> getPrivateHeapTypes(Module& wasm, WorldMode worldMode);

struct IndexedHeapTypes {
std::vector<HeapType> types;
Expand Down
12 changes: 6 additions & 6 deletions src/ir/possible-contents.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -684,7 +684,7 @@ struct InfoCollector
SignatureResultLocation{func->type.getHeapType(), i}});
}

if (!options.closedWorld) {
if (options.worldMode == WorldMode::Open) {
info.calledFromOutside.insert(curr->func);
}
}
Expand Down Expand Up @@ -1711,7 +1711,7 @@ void TNHOracle::scan(Function* func,
void visitCallRef(CallRef* curr) {
// We can only optimize call_ref in closed world, as otherwise the
// call can go somewhere we can't see.
if (options.closedWorld) {
if (options.worldMode == WorldMode::Closed) {
info.callRefs.push_back(curr);
}
}
Expand Down Expand Up @@ -1834,7 +1834,7 @@ void TNHOracle::infer() {
// that type or a subtype, i.e., might be called when that type is seen in a
// call_ref target.
std::unordered_map<HeapType, std::vector<Function*>> typeFunctions;
if (options.closedWorld) {
if (options.worldMode == WorldMode::Closed) {
for (auto& func : wasm.functions) {
auto type = func->type;
auto& info = map[wasm.getFunction(func->name)];
Expand Down Expand Up @@ -1895,7 +1895,7 @@ void TNHOracle::infer() {
// We should only get here in a closed world, in which we know which
// functions might be called (the scan phase only notes callRefs if we are
// in fact in a closed world).
assert(options.closedWorld);
assert(options.worldMode == WorldMode::Closed);

auto iter = typeFunctions.find(targetType.getHeapType());
if (iter == typeFunctions.end()) {
Expand Down Expand Up @@ -2535,8 +2535,8 @@ Flower::Flower(Module& wasm, const PassOptions& options)
}

// In open world, public heap types may be written to from the outside.
if (!options.closedWorld) {
for (auto type : ModuleUtils::getPublicHeapTypes(wasm)) {
if (options.worldMode == WorldMode::Open) {
for (auto type : ModuleUtils::getPublicHeapTypes(wasm, options.worldMode)) {
if (type.isStruct()) {
auto& fields = type.getStruct().fields;
for (Index i = 0; i < fields.size(); i++) {
Expand Down
3 changes: 2 additions & 1 deletion src/ir/type-updating.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,15 @@

namespace wasm {

GlobalTypeRewriter::GlobalTypeRewriter(Module& wasm)
GlobalTypeRewriter::GlobalTypeRewriter(Module& wasm, WorldMode worldMode)
: wasm(wasm), publicGroups(wasm.features) {
// Find the heap types that are not publicly observable. Even in a closed
// world scenario, don't modify public types because we assume that they may
// be reflected on or used for linking. Figure out where each private type
// will be located in the builder.
typeInfo = ModuleUtils::collectHeapTypeInfo(
wasm,
worldMode,
ModuleUtils::TypeInclusion::UsedIRTypes,
ModuleUtils::VisibilityHandling::FindVisibility);

Expand Down
18 changes: 11 additions & 7 deletions src/ir/type-updating.h
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ class GlobalTypeRewriter {
// private types do not conflict with public types.
UniqueRecGroups publicGroups;

GlobalTypeRewriter(Module& wasm);
GlobalTypeRewriter(Module& wasm, WorldMode worldMode);
virtual ~GlobalTypeRewriter() {}

// Main entry point. This performs the entire process of creating new heap
Expand Down Expand Up @@ -427,7 +427,9 @@ class GlobalTypeRewriter {

// Helper for the repeating pattern of just updating Signature types using a
// map of old heap type => new Signature.
static void updateSignatures(const SignatureUpdates& updates, Module& wasm) {
static void updateSignatures(const SignatureUpdates& updates,
Module& wasm,
WorldMode worldMode) {
if (updates.empty()) {
return;
}
Expand All @@ -436,8 +438,10 @@ class GlobalTypeRewriter {
const SignatureUpdates& updates;

public:
SignatureRewriter(Module& wasm, const SignatureUpdates& updates)
: GlobalTypeRewriter(wasm), updates(updates) {
SignatureRewriter(Module& wasm,
const SignatureUpdates& updates,
WorldMode worldMode)
: GlobalTypeRewriter(wasm, worldMode), updates(updates) {
update();
}

Expand All @@ -448,7 +452,7 @@ class GlobalTypeRewriter {
sig.results = getTempType(iter->second.results);
}
}
} rewriter(wasm, updates);
} rewriter(wasm, updates, worldMode);
}

protected:
Expand All @@ -473,8 +477,8 @@ class TypeMapper : public GlobalTypeRewriter {

const TypeUpdates& mapping;

TypeMapper(Module& wasm, const TypeUpdates& mapping)
: GlobalTypeRewriter(wasm), mapping(mapping) {}
TypeMapper(Module& wasm, const TypeUpdates& mapping, WorldMode worldMode)
: GlobalTypeRewriter(wasm, worldMode), mapping(mapping) {}

void map() {
// Update the internals of types (struct fields, signatures, etc.) to
Expand Down
Loading
Loading