实现版本:Spring 2025

实现时间:2025/7/25

MIT 6.5840 Distributed System Lab 3 个人踩坑记录。

3A

重置选举时间错误,这个在3A测试并没有问题,但在3C会发现 TestFigure83C / TestFigure8Unreliable3C 可能会出现最后一个 ts.one 超时。

重置选举时间只有以下三个情况:

  • AppendEntries (任期至少一样新)。
  • 选举超时。
  • RequestVote进行投票(未投票不需要重置)。

个人在 reply.Term > currentTerm 转换状态时重置了选举时间导致了错误,修改后3C测试 1000 次无误。

如果你要使用timer进行重置,请注意timer的正确用法。

个人重置:

// Election watcher: until this peer is killed, wait for either the election
// timer to fire or a reset request to arrive on resetCh.
for !rf.killed() {
    select {
    case <-rf.electionTimer.C:
        // The timer expired without being reset: start an election.
        rf.election()
    case <-rf.resetCh:
        // Reset request. Every timer operation (Stop/Reset/<-C) is performed
        // in this one goroutine so they can never race with each other.

        // Before Go 1.23 a stopped timer could still have a stale value
        // sitting in its channel, which had to be drained before Reset:
        // if !rf.electionTimer.Stop() {
        //     select {
        //     case <-rf.electionTimer.C:
        //     default:
        //     }
        // }

        // Go 1.23+: Stop followed by Reset is sufficient, no drain needed.
        rf.electionTimer.Stop()
        rf.electionTimer.Reset(randElectionTimeout())
    }
}

// sendResetSignal asks the election loop to restart the election timer.
//
// The send is non-blocking: if resetCh already holds a pending signal the
// new one is dropped, because a single pending signal is enough to trigger
// a reset. resetCh must therefore be created with a capacity greater than 0;
// with an unbuffered channel the signal would be lost whenever the receiver
// is not already waiting on it.
func (rf *Raft) sendResetSignal() {
    select {
    case rf.resetCh <- struct{}{}:
    default:
    }
}

resetCh:       make(chan struct{}, 1), // must > 0

3B

  1. AppendEntries匹配后直接截断后追加Entries。
log.resize(args.PrevLogIndex)
log.appendLogs(args.Entries)

思考这样一个情况:

leader:
log: [1]
leader send AppendEntries -> A1
log: [1, 2]
leader send AppendEntries -> A2

follower:
先收到 A2:
log -> [1, 2]
再收到 A1:
log -> [1]

这样就会导致follower的log回退,进而导致commit出问题。

应按照论文上说的追加不存在的日志。

// Find the first incoming entry that conflicts with what is already stored
// at the same absolute index (same index, different term). conflictIndex is
// a position inside args.Entries, or -1 when no conflict was found.
conflictIndex := -1
for i := 0; i < len(args.Entries); i++ {
    absIndex := args.PrevLogIndex + i + 1
    if absIndex >= rf.log.getAbsLogSize() {
        // Past the end of the local log: nothing left to compare against.
        break
    }

    if rf.log.getAbsLog(absIndex).Term != args.Entries[i].Term {
        conflictIndex = i
        break
    }
}

if conflictIndex >= 0 {
    // Drop the conflicting entry and everything after it, then take the
    // leader's entries from the conflict position onwards.
    rf.log.resizeAbs(args.PrevLogIndex + conflictIndex + 1)
    rf.log.appendLogs(args.Entries[conflictIndex:])
    rf.persist()
} else {
    // No conflict: never truncate. Append only the entries that extend past
    // the current end of the log, so a stale (reordered) AppendEntries
    // cannot roll the log back.
    existingCount := rf.log.getAbsLogSize() - args.PrevLogIndex - 1
    toAppend := len(args.Entries) - existingCount
    if toAppend > 0 {
        rf.log.appendLogs(args.Entries[existingCount:])
        rf.persist()
    }
}
  2. 不要在持有锁时向 applyCh 发送数据,会导致死锁。
// Anti-pattern (do NOT do this): rf.mu is held across a channel send that
// blocks until the service reads from applyCh.
rf.mu.Lock()
applyCh <- msg
rf.mu.Unlock()

思考这样一个情况:

下层 Raft:
rf.mu.Lock()   // 持有锁A
applyCh <- msg // 发生阻塞

上层 Service
未读取 applyCh 的内容
上层加锁B
调用 Start(cmd), 这时候会卡住,等待锁A释放。
因为上层Service在等待锁A释放,持有锁B无法处理applyCh的内容,而下层Raft在等待上层读取applyCh的内容,造成循环等待,发生死锁。

个人建议利用buffer保存需要apply的数据,解锁后将buffer的内容apply。

3C

3C会发现很多3A+3B出现的错误。

  1. TestFigure83C / TestFigure8Unreliable3C 可能出现 apply error: apply out of order ,这个Bug测试十几次才会出现一次。

debug发现kill后重启raft,apply协程还存在,这时候旧的协程往旧的applyCh传数据会出现错误。

分析原因:rf.applyCh <- msg 阻塞导致未及时退出,apply前要检测是否kill。

// Deliver the collected batch on applyCh. rf.mu must NOT be held here,
// because the send blocks until the service reads the message.
for _, msg := range batch {
    // Re-check before every send: once this instance has been killed, stop
    // so that no further stale messages are pushed to its (old) applyCh.
    if rf.killed() {
        return
    }
    rf.applyCh <- msg
}

3D

发现测试框架的Bug,会导致DATA RACE。

server.go

func newRfsrv
    go s.applierSnap(applyCh)

func (rs *rfsrv) Kill()
    rs.mu.Lock()
    rs.raft = nil // tester will call Kill() on rs.raft
    rs.mu.Unlock()

func (rs *rfsrv) applierSnap(applyCh chan raftapi.ApplyMsg)
    if rs.raft == nil {
        return // ???
    }
    rs.raft.Snapshot(m.CommandIndex, w.Bytes())

可以看到Kill处修改raft加了锁,而在applierSnap两处访问raft都未加锁,有可能在rs.raft.Snapshot这里出现DATA RACE或者nil产生panic。

修改如下(只需要修改applierSnap):

// Remove this check, which reads rs.raft without holding rs.mu:
if rs.raft == nil {
    return // ???
}

// Instead, perform the nil check and the rs.raft.Snapshot call while
// holding rs.mu, the same lock Kill() takes before setting rs.raft to nil:
rs.mu.Lock()
if rs.raft == nil {
    rs.mu.Unlock()
    return
} else {
    rs.raft.Snapshot(m.CommandIndex, w.Bytes())
    rs.mu.Unlock()
}
  1. 误以为leader运用快照的时间一定比follower早。

在AppendEntries时未判定args.PrevLogIndex 是否小于 follower 的 LastIncludedIndex ,导致访问出现panic。

LastIncludedIndex 由快照产生并修改,误以为leader的LastIncludedIndex 一定比follower的LastIncludedIndex 更新。

增加相关判定:

// PrevLogIndex less than follower LastIncludedIndex, (nextIndex[server] -> XLen)
if args.PrevLogIndex < rf.log.startAt() {
    return
}
  2. InstallSnapshot 需要判定 args.LastIncludedIndex <= rf.lastApplied,如果已经应用到状态机就不需要再应用快照。

这个未判定似乎不会导致Fail,但可以减少不必要的操作。

  3. raft 重启后需要向上层传递当前快照,同时需要设置lastApplied为LastIncludedIndex。

个人建议:

Log 可以单独写一个类处理, 而不是直接在raft内维护log和lastIncludedIndex,这样只需要调用对应的函数即可。

// LogItem is one Raft log entry: a command together with the term in which
// it was appended.
type LogItem struct {
    Command any
    Term    int
}

// Log is a Raft log addressed by absolute index.
//
// Log[0] stands for absolute index LastIncludedIndex, so the entry with
// absolute index i lives at Log[i-LastIncludedIndex].
type Log struct {
    Log               []LogItem
    LastIncludedIndex int
}

// NewLog returns a log that contains only a zero-valued placeholder entry at
// absolute index 0.
func NewLog() Log {
    return Log{Log: make([]LogItem, 1)}
}

// rel translates an absolute index into a position inside l.Log.
func (l *Log) rel(abs int) int {
    return abs - l.LastIncludedIndex
}

// getAbsLog returns the entry stored at the given absolute index.
func (l *Log) getAbsLog(index int) LogItem {
    return l.Log[l.rel(index)]
}

// getAbsLogs returns a copy of the entries in the absolute range [start, end).
// The result does not share storage with the log.
func (l *Log) getAbsLogs(start int, end int) []LogItem {
    out := make([]LogItem, end-start)
    copy(out, l.Log[l.rel(start):l.rel(end)])
    return out
}

// getAbsLogSize returns the absolute size of the log, i.e. one more than the
// absolute index of the last stored entry.
func (l *Log) getAbsLogSize() int {
    return l.LastIncludedIndex + len(l.Log)
}

// resizeAbs cuts the log so that its absolute size becomes size.
func (l *Log) resizeAbs(size int) {
    keep := l.rel(size)
    l.Log = l.Log[:keep]
}

// resetAbs discards every entry and restarts the log at absolute index index
// with a single placeholder entry carrying term.
func (l *Log) resetAbs(index int, term int) {
    l.Log = []LogItem{{Term: term}}
    l.LastIncludedIndex = index
}

// compact drops every entry before absolute index index. The entry at index
// becomes the new Log[0], and the remaining entries are copied into a fresh
// slice so the old backing array is no longer referenced by the log.
func (l *Log) compact(index int) {
    tail := l.Log[l.rel(index):]
    fresh := make([]LogItem, len(tail))
    copy(fresh, tail)
    l.Log = fresh
    l.LastIncludedIndex = index
}

// lastLog returns the last stored entry.
func (l *Log) lastLog() LogItem {
    last := len(l.Log) - 1
    return l.Log[last]
}

// appendLog adds one entry at the end of the log.
func (l *Log) appendLog(log LogItem) {
    l.Log = append(l.Log, log)
}

// appendLogs adds all given entries, in order, at the end of the log.
func (l *Log) appendLogs(logs []LogItem) {
    l.Log = append(l.Log, logs...)
}

// startAt returns the absolute index represented by Log[0].
func (l *Log) startAt() int {
    return l.LastIncludedIndex
}

// findFirstAbsIndexOfTerm returns the absolute index of the first stored
// entry (Log[0] included) whose term equals term. When no entry matches it
// returns startAt(), which is the same value a match at Log[0] would produce.
func (l *Log) findFirstAbsIndexOfTerm(term int) int {
    for pos := range l.Log {
        if l.Log[pos].Term == term {
            return pos + l.LastIncludedIndex
        }
    }
    return l.LastIncludedIndex
}

dtest

测试脚本:

#!/usr/bin/env python

import itertools
import math
import signal
import subprocess
import tempfile
import shutil
import time
import os
import sys
import datetime
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, wait, FIRST_COMPLETED
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Dict, DefaultDict, Tuple

import typer
import rich
from rich import print
from rich.table import Table
from rich.progress import (
    Progress,
    TimeElapsedColumn,
    TimeRemainingColumn,
    TextColumn,
    BarColumn,
    SpinnerColumn,
)
from rich.live import Live
from rich.panel import Panel
from rich.traceback import install

install(show_locals=True)


@dataclass
class StatsMeter:
    """
    Auxiliary class to keep track of online stats including: count, mean, variance.

    Uses Welford's algorithm to update the mean and the running sum of squared
    deviations (``S``) one datum at a time, without storing the data.
    https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#On-line_algorithm

    ``variance`` is the population variance (``S / n``) and ``std`` its square
    root. Both are 0.0 while no datum has been added.
    """

    n: int = 0
    mean: float = 0.0
    S: float = 0.0

    def add(self, datum):
        """Fold one observation into the running count, mean and S."""
        self.n += 1
        delta = datum - self.mean
        # Mk = Mk-1 + (xk - Mk-1)/k
        self.mean += delta / self.n
        # Sk = Sk-1 + (xk - Mk-1)*(xk - Mk).
        self.S += delta * (datum - self.mean)

    @property
    def variance(self):
        """Population variance of the data seen so far (0.0 when empty)."""
        if self.n == 0:
            # An empty meter (e.g. one freshly created by a defaultdict
            # lookup) has no spread; avoid dividing by zero.
            return 0.0
        return self.S / self.n

    @property
    def std(self):
        """Population standard deviation of the data seen so far."""
        return math.sqrt(self.variance)


def print_results(results: Dict[str, Dict[str, StatsMeter]], timing=False):
    """Render a per-test summary table.

    One row is printed for every test that completed at least once: its name
    (green when it never failed, red otherwise), the failure count, the run
    count, and "mean ± std" for each time statistic. With ``timing`` the
    real/user/system times are shown; otherwise only the wall-clock ``time``.
    """
    if timing:
        headings = ("Real Time", "User Time", "System Time")
        stat_keys = ("real_time", "user_time", "system_time")
    else:
        headings = ("Time",)
        stat_keys = ("time",)

    table = Table(show_header=True, header_style="bold")
    table.add_column("Test")
    for heading in ("Failed", "Total", *headings):
        table.add_column(heading, justify="right")

    for name, meters in results.items():
        runs = meters["completed"].n
        if runs == 0:
            # Nothing finished for this test yet, so there is nothing to report.
            continue
        failures = meters["failed"].n
        colour = "red" if failures else "green"
        cells = [f"[{colour}]{name}[/{colour}]", str(failures), str(runs)]
        cells += [f"{meters[key].mean:.2f} ± {meters[key].std:.2f}" for key in stat_keys]
        table.add_row(*cells)

    print(table)


def run_test(test: str, race: bool, timing: bool):
    test_cmd = ["go", "test", f"-run={test}"]
    if race:
        test_cmd.append("-race")
    if timing:
        if sys.platform.startswith("linux"):
            test_cmd = ["time", "-p", "-f", "%e real         %U user         %S sys"] + test_cmd
        else:
            test_cmd = ["time"] + test_cmd
    f, path = tempfile.mkstemp()
    start = time.time()
    proc = subprocess.run(test_cmd, stdout=f, stderr=f)
    runtime = time.time() - start
    os.close(f)
    return test, path, proc.returncode, runtime


def last_line(file: str) -> str:
    """Return the last line of ``file``, decoded, without reading the whole file.

    The file is scanned backwards from its end until the newline that precedes
    the last line is found. A trailing newline on the last line is kept. For a
    file with a single line, fewer than two bytes, or no content at all, the
    whole content (possibly the empty string) is returned.
    """
    with open(file, "rb") as f:
        try:
            # Start before the final byte so a trailing "\n" is not mistaken
            # for the separator in front of the last line.
            f.seek(-2, os.SEEK_END)
            while f.read(1) != b"\n":
                # read(1) advanced by one; step back two to look at the
                # previous byte.
                f.seek(-2, os.SEEK_CUR)
        except OSError:
            # Seeking before the start of the file: there is no earlier
            # newline, so the last line begins at offset 0.
            f.seek(0)
        line = f.readline().decode()
    return line


# fmt: off
def run_tests(
    tests: List[str],
    sequential: bool       = typer.Option(False,  '--sequential',      '-s',    help='Run all test of each group in order'),
    workers: int           = typer.Option(1,      '--workers',         '-p',    help='Number of parallel tasks'),
    iterations: int        = typer.Option(10,     '--iter',            '-n',    help='Number of iterations to run'),
    output: Optional[Path] = typer.Option(None,   '--output',          '-o',    help='Output path to use'),
    verbose: int           = typer.Option(0,      '--verbose',         '-v',    help='Verbosity level', count=True),
    archive: bool          = typer.Option(False,  '--archive',         '-a',    help='Save all logs intead of only failed ones'),
    race: bool             = typer.Option(False,  '--race/--no-race',  '-r/-R', help='Run with race checker'),
    loop: bool             = typer.Option(False,  '--loop',            '-l',    help='Run continuously'),
    growth: int            = typer.Option(10,     '--growth',          '-g',    help='Growth ratio of iterations when using --loop'),
    timing: bool           = typer.Option(False,   '--timing',          '-t',    help='Report timing, only works on macOS and linux'),
    # fmt: on
):
    # CLI entry point: run every test in `tests` `iterations` times using up
    # to `workers` concurrent `go test` processes, show live progress, keep
    # the logs of failed runs (all runs with --archive) under `output`, and
    # print a summary table at the end. With --loop the whole batch repeats
    # forever, multiplying `iterations` by `growth` after each round.
    #
    # NOTE(review): documented with comments rather than a docstring because
    # typer appears to surface a command's docstring as its --help text;
    # confirm before converting this into a docstring.

    # Default log directory: a timestamp-named path relative to the CWD.
    if output is None:
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        output = Path(timestamp)

    if race:
        print("[yellow]Running with the race detector\n[/yellow]")

    # The verbosity level is exported through the VERBOSE environment
    # variable of this process.
    if verbose > 0:
        print(f"[yellow] Verbosity level set to {verbose}[/yellow]")
        os.environ['VERBOSE'] = str(verbose)

    # One pass of this loop is one full batch; it only repeats with --loop.
    while True:

        total = iterations * len(tests)
        completed = 0

        # results[test][stat_name] -> StatsMeter, created lazily on first use.
        results = {test: defaultdict(StatsMeter) for test in tests}

        # Build the flat schedule of test names:
        #   sequential: A A A ... B B B ...   (each test's runs grouped)
        #   otherwise:  A B C A B C ...       (round-robin over the tests)
        if sequential:
            test_instances = itertools.chain.from_iterable(itertools.repeat(test, iterations) for test in tests)
        else:
            test_instances = itertools.chain.from_iterable(itertools.repeat(tests, iterations))
        test_instances = iter(test_instances)

        # Overall progress bar across all runs of all tests.
        total_progress = Progress(
            "[progress.description]{task.description}",
            BarColumn(),
            TimeRemainingColumn(),
            "[progress.percentage]{task.percentage:>3.0f}%",
            TimeElapsedColumn(),
        )
        total_task = total_progress.add_task("[yellow]Tests[/yellow]", total=total)

        # One progress bar per test.
        task_progress = Progress(
            "[progress.description]{task.description}",
            SpinnerColumn(),
            BarColumn(),
            "{task.completed}/{task.total}",
        )
        tasks = {test: task_progress.add_task(test, total=iterations) for test in tests}

        progress_table = Table.grid()
        progress_table.add_row(total_progress)
        progress_table.add_row(Panel.fit(task_progress))

        with Live(progress_table, transient=True) as live:

            # On Ctrl-C: stop the live display, print what has been collected
            # so far (always with the non-timing columns) and exit with 1.
            def handler(_, frame):
                live.stop()
                print('\n')
                print_results(results)
                sys.exit(1)

            signal.signal(signal.SIGINT, handler)

            with ThreadPoolExecutor(max_workers=workers) as executor:

                futures = []
                while completed < total:
                    # Top up the in-flight set to `workers` runs, pulling the
                    # next names from the schedule.
                    n = len(futures)
                    if n < workers:
                        for test in itertools.islice(test_instances, workers-n):
                            futures.append(executor.submit(run_test, test, race, timing))

                    # Block until at least one run has finished.
                    done, not_done = wait(futures, return_when=FIRST_COMPLETED)

                    for future in done:
                        test, path, rc, runtime = future.result()

                        results[test]['completed'].add(1)
                        results[test]['time'].add(runtime)
                        task_progress.update(tasks[test], advance=1)
                        dest = (output / f"{test}_{completed}.log").as_posix()
                        if rc != 0:
                            print(f"Failed test {test} - {dest}")
                            task_progress.update(tasks[test], description=f"[red]{test}[/red]")
                            results[test]['failed'].add(1)
                        else:
                            # Turn the label green only once every iteration
                            # of this test has finished without a failure.
                            if results[test]['completed'].n == iterations and results[test]['failed'].n == 0:
                                task_progress.update(tasks[test], description=f"[green]{test}[/green]")

                        # Keep the log of a failed run, or of every run with
                        # --archive.
                        if rc != 0 or archive:
                            output.mkdir(exist_ok=True, parents=True)
                            shutil.copy(path, dest)

                        # With --timing the last output line comes from
                        # `time`; it is split into six space-separated fields
                        # of which the 1st, 3rd and 5th are the real, user
                        # and system times.
                        if timing:
                            line = last_line(path)
                            real, _, user, _, system, _ = line.replace(' '*8, '').split(' ')
                            results[test]['real_time'].add(float(real))
                            results[test]['user_time'].add(float(user))
                            results[test]['system_time'].add(float(system))

                        # The temp file from run_test is no longer needed.
                        os.remove(path)

                        completed += 1
                        total_progress.update(total_task, advance=1)

                        # Carry the still-running futures into the next round
                        # (reassigned on every iteration of this for-loop,
                        # always to the same value).
                        futures = list(not_done)

        print_results(results, timing)

        if loop:
            iterations *= growth
            print(f"[yellow]Increasing iterations to {iterations}[/yellow]")
        else:
            break


# Script entry point: hand run_tests to typer, which is called only when this
# file is executed directly (not when it is imported).
if __name__ == "__main__":
    typer.run(run_tests)