Skip to content
Merged
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,6 @@ dist/
.idea/
.vscode/
*.DS_Store

# python venv
.venv/
83 changes: 78 additions & 5 deletions cli/src/args.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,37 @@ import { Map, Set } from 'immutable'
import type { DataType, Network, TaskProvider } from "@epfml/discojs";
import { defaultTasks } from '@epfml/discojs'

interface BenchmarkArguments {
/** Weight-aggregation strategies the CLI accepts. */
type AggregationStrategy = "mean" | "byzantine" | "secure";

/**
 * Validates a raw CLI string as an {@link AggregationStrategy}.
 *
 * @param raw - value passed on the command line (e.g. via `--aggregator`)
 * @returns the value, narrowed to the `AggregationStrategy` union
 * @throws Error if `raw` is not one of the supported strategies
 */
function parseAggregator(raw: string): AggregationStrategy {
  // Strict equality throughout (original mixed `===` and `==`).
  if (raw === "mean" || raw === "byzantine" || raw === "secure")
    return raw;
  throw new Error(`Aggregator ${raw} is not supported.`);
}

/**
 * Fully-validated benchmark CLI arguments.
 *
 * Optional fields are grouped by the feature they configure (DP,
 * aggregation strategy, byzantine tuning, secure aggregation) and are only
 * meaningful when the matching strategy/privacy setting is selected.
 */
export interface BenchmarkArguments {
// Task provider supplying the model/task definition to benchmark.
provider: TaskProvider<DataType, Network>;
// Identifier of this run; used to name the output log directories.
testID: string
// Number of simulated federated users.
numberOfUsers: number
epochs: number
// Round duration, measured in epochs.
roundDuration: number
batchSize: number
// Fraction of the dataset held out for validation (e.g. 0.2).
validationSplit: number

// DP — differential-privacy parameters (all three optional).
epsilon?: number
delta?: number
dpDefaultClippingRadius?: number
// Aggregator — which weight-aggregation strategy to use.
aggregator: AggregationStrategy
// Byzantine aggregator — centered-clipping tuning knobs.
clippingRadius?: number
maxIterations?: number
beta?: number
// Secure aggregator — bound on absolute weight values for secret sharing.
maxShareValue?: number

// Whether to persist benchmark logs to disk.
save: boolean
// Server URL to connect the clients to.
host: URL
}
Expand All @@ -27,15 +48,13 @@ const argExample = 'e.g. npm start -- -u 2 -e 3 # runs 2 users for 3 epochs'

const unsafeArgs = parse<BenchmarkUnsafeArguments>(
{
testID: { type: String, alias: 'i', description: 'ID of the testcase' },
task: { type: String, alias: 't', description: 'Task: tinder_dog, titanic, simple_face, cifar10 or lus_covid', defaultValue: 'tinder_dog' },
numberOfUsers: { type: Number, alias: 'u', description: 'Number of users', defaultValue: 2 },
epochs: { type: Number, alias: 'e', description: 'Number of epochs', defaultValue: 10 },
roundDuration: { type: Number, alias: 'r', description: 'Round duration (in epochs)', defaultValue: 2 },
batchSize: { type: Number, alias: 'b', description: 'Training batch size', defaultValue: 10 },
validationSplit : { type: Number, alias: 'v', description: 'Validation dataset ratio', defaultValue: 0.2 },
epsilon: { type: Number, alias: 'n', description: 'Privacy budget', optional: true, defaultValue: undefined},
delta: { type: Number, alias: 'd', description: 'Probability of failure, slack parameter', optional: true, defaultValue: undefined},
dpDefaultClippingRadius: {type: Number, alias: 'f', description: 'Default clipping radius for DP', optional: true, defaultValue: undefined},
save: { type: Boolean, alias: 's', description: 'Save logs of benchmark', defaultValue: false },
host: {
type: (raw: string) => new URL(raw),
Expand All @@ -44,6 +63,22 @@ const unsafeArgs = parse<BenchmarkUnsafeArguments>(
defaultValue: new URL("http://localhost:8080"),
},

// Aggregator
aggregator: { type: parseAggregator, description: 'Type of weight aggregator', defaultValue: 'mean' },

// Byzantine aggregator
clippingRadius: { type: Number, description: "Clipping radius for centered clipping", optional: true },
maxIterations: { type: Number, description: "Maximum centered clipping iterations", optional: true },
beta: { type: Number, description: "Momentum coefficient to smooth the aggregation over multiple rounds", optional: true },

// Secure aggregator
maxShareValue: { type: Number, description: "Maximum absolute value over all the weights", optional: true },

// Differential Privacy
epsilon: { type: Number, description: 'Privacy budget', optional: true, defaultValue: undefined},
delta: { type: Number, description: 'Probability of failure, slack parameter', optional: true, defaultValue: undefined},
dpDefaultClippingRadius: {type: Number, description: 'Default clipping radius for DP', optional: true, defaultValue: undefined},

help: { type: Boolean, optional: true, alias: 'h', description: 'Prints this usage guide' }
},
{
Expand Down Expand Up @@ -88,6 +123,44 @@ export const args: BenchmarkArguments = {
task.trainingInformation.epochs = unsafeArgs.epochs;
task.trainingInformation.validationSplit = unsafeArgs.validationSplit;

const { aggregator, clippingRadius, maxIterations, beta, maxShareValue } = unsafeArgs;

// Aggregation strategy: always defined (parse gives it a "mean" default),
// but guard anyway so an absent value never clobbers the task's setting.
if (aggregator !== undefined)
  task.trainingInformation.aggregationStrategy = aggregator;

// Byzantine-robust aggregation: the three centered-clipping knobs must be
// provided together; partial configuration is silently ignored.
if (
  clippingRadius !== undefined &&
  maxIterations !== undefined &&
  beta !== undefined
) {
  if (task.trainingInformation.scheme === "local")
    throw new Error("Byzantine aggregator is not supported for local training");
  if (task.trainingInformation.aggregationStrategy !== "byzantine")
    throw new Error("Byzantine parameters can be set only when aggregationStrategy is byzantine");

  // Spread to preserve any privacy settings the task already defines.
  task.trainingInformation.privacy = {
    ...task.trainingInformation.privacy,
    byzantineFaultTolerance: {
      clippingRadius,
      maxIterations,
      beta,
    },
  };
}

// Secure aggregation: only valid for decentralized learning with the
// "secure" strategy selected.
if (maxShareValue !== undefined) {
  if (task.trainingInformation.scheme !== "decentralized")
    throw new Error("Secure aggregation is only supported for decentralized learning");
  if (task.trainingInformation.aggregationStrategy !== "secure")
    throw new Error("maxShareValue can only be set when aggregationStrategy is secure");

  task.trainingInformation.maxShareValue = maxShareValue;
}

// For DP
const {dpDefaultClippingRadius, epsilon, delta} = unsafeArgs;

Expand All @@ -102,7 +175,7 @@ export const args: BenchmarkArguments = {
const defaultRadius = dpDefaultClippingRadius ? dpDefaultClippingRadius : 1;

// for the case where privacy parameters are not defined in the default tasks
task.trainingInformation.privacy ??= {}
task.trainingInformation.privacy ??= {};
task.trainingInformation.privacy.differentialPrivacy = {
clippingRadius: defaultRadius,
epsilon: epsilon,
Expand Down
86 changes: 71 additions & 15 deletions cli/src/cli.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@ import "@tensorflow/tfjs-node"

import { List, Range } from 'immutable'
import fs from 'node:fs/promises'
import { createWriteStream } from "node:fs";
import path from "node:path";

import type {
Dataset,
DataFormat,
DataType,
RoundLogs,
SummaryLogs,
Task,
TaskProvider,
Network,
Expand All @@ -17,49 +19,103 @@ import { Disco, aggregator as aggregators, client as clients } from '@epfml/disc

import { getTaskData } from './data.js'
import { args } from './args.js'
import { makeUserLogFile } from "./user_log.js";
import type { UserLogFile } from "./user_log.js";

// Array.fromAsync not yet widely used (2024)
/** Drains an async iterable into a plain array, preserving order. */
async function arrayFromAsync<T>(iter: AsyncIterable<T>): Promise<T[]> {
  const collected: T[] = [];
  const iterator = iter[Symbol.asyncIterator]();
  for (let step = await iterator.next(); !step.done; step = await iterator.next()) {
    collected.push(step.value);
  }
  return collected;
}

/**
 * NOTE(review): this span is a flattened GitHub diff — it interleaves BOTH
 * the old and new versions of `runUser` (duplicate signature tails, a dead
 * old body, and review-UI residue lines) and is not valid TypeScript as-is.
 * Comments below mark which lines belong to which diff side.
 *
 * New-side behavior: trains one simulated user via Disco, optionally
 * streams per-round summary logs to an NDJSON file while training, writes
 * a final per-user JSON log when --save is set, and always closes the log
 * stream and the Disco client in `finally`.
 */
async function runUser<D extends DataType, N extends Network>(
task: Task<D, N>,
url: URL,
data: Dataset<DataFormat.Raw[D]>,
// OLD signature tail (removed side of the diff):
): Promise<List<RoundLogs>> {
// NEW signature tail (added side of the diff):
userIndex: number,
numberOfUsers: number,
): Promise<List<SummaryLogs>> {
// cast as typescript isn't good with generics
const trainingScheme = task.trainingInformation.scheme as N
const aggregator = aggregators.getAggregator(task)
const client = clients.getClient(trainingScheme, url, task, aggregator)
const disco = new Disco(task, client, { scheme: trainingScheme });

// OLD implementation (removed side): train by round, sleep, close, return.
const logs = List(await arrayFromAsync(disco.trainByRound(data)));
await new Promise((res, _) => setTimeout(() => res('timeout'), 1000)) // Wait for other peers to finish
await disco.close();
return logs;
// NEW implementation (added side) starts here.
// Per-run output directory: ./<testID>/
const dir = path.join(".", `${args.testID}`);
await fs.mkdir(dir, { recursive: true });
const streamPath = path.join(dir, `client${userIndex}_local_log.ndjson`);

const finalLog: SummaryLogs[] = [];
// create a write stream that saves learning logs during the train
let ndjsonStream: ReturnType<typeof createWriteStream> | null = null;

if (args.save){
ndjsonStream = createWriteStream(streamPath, {flags: "w"});
}

// NOTE(review): the next three lines are GitHub review-UI residue from the
// scraped page, not code.
Comment thread
JulienVig marked this conversation as resolved.
Outdated
try{
// Stream one summary-log entry per training round.
for await (const log of disco.trainSummary(data)){
finalLog.push(log);

if (ndjsonStream){
// One JSON object per line (NDJSON).
ndjsonStream.write(JSON.stringify(log) + "\n");
}
}

await new Promise((res, _) => setTimeout(() => res('timeout'), 1000)) // Wait for other peers to finish

// saving the entire per-user logs
if (args.save) {
const finalPath = path.join(dir, `client${userIndex}_local_log.json`);

const userLog: UserLogFile = makeUserLogFile(task, numberOfUsers, userIndex, client.ownId, finalLog);

await fs.writeFile(finalPath, JSON.stringify(userLog, null, 2));
}

return List(finalLog);
}catch(err){
console.error(`Run user failed for client ${userIndex}: `, err);
throw err;
}finally{
// Best-effort cleanup: failures here are logged, never rethrown, so they
// cannot mask an error from the try block.
try{
if (ndjsonStream){
ndjsonStream.end();

// Wait until buffered writes are flushed before returning.
await new Promise<void>((resolve, reject) => {
ndjsonStream.once("finish", resolve);
ndjsonStream.once("error", reject);
});
}
}catch(err){
console.error(`failed to close log stream for client ${userIndex}: `, err);
}

try{
await disco.close();
}catch(err){
console.error(`failed to close disco for client ${userIndex}: `, err);
}
}
}

/**
 * Entry point: splits the task's data across `numberOfUsers` simulated
 * clients, runs them concurrently against `args.host`, and (with --save)
 * writes the combined logs under ./<testID>/<taskId>/.
 *
 * NOTE(review): flattened GitHub diff — both the old and new `runUser`
 * invocations and both the old CSV-named and new JSON save paths appear
 * below; comments mark the removed-side lines.
 */
async function main<D extends DataType, N extends Network>(
provider: TaskProvider<D, N>,
numberOfUsers: number,
): Promise<void> {
const task = await provider.getTask();
console.log(`Test ID: ${args.testID}`)
console.log(`Started ${task.trainingInformation.scheme} training of ${task.id}`)
console.log({ args })

// One data partition per simulated user, fetched in parallel.
const dataSplits = await Promise.all(
Range(0, numberOfUsers).map(async i => getTaskData(task.id, i, numberOfUsers))
)
const logs = await Promise.all(
// OLD call (removed side of the diff):
dataSplits.map(async data => await runUser(task, args.host, data as Dataset<DataFormat.Raw[D]>))
// NEW call (added side): passes the user index and total user count.
dataSplits.map((data, i) => runUser(task, args.host, data as Dataset<DataFormat.Raw[D]>, i, numberOfUsers))
)

if (args.save) {
// OLD save path (removed side): CSV-named file with JSON content.
const fileName = `${task.id}_${numberOfUsers}users.csv`;
await fs.writeFile(fileName, JSON.stringify(logs, null, 2));
// NEW save path (added side): JSON file under ./<testID>/<taskId>/.
const dir = path.join(".", `${args.testID}`, `${task.id}`);
await fs.mkdir(dir, { recursive: true });

const filePath = path.join(dir, `${task.id}_${numberOfUsers}users.json`);
await fs.writeFile(filePath, JSON.stringify(logs, null, 2));
}
}

Expand Down
Loading