Cairo Academy: Understanding a Stake Contract

This chapter walks through a basic staking contract implemented in Cairo. The contract demonstrates the core functionality of staking and serves as a practical example of how to build a staking contract on Starknet.

Purpose and Functionality

The Cairo code below defines a basic staking contract. It implements the following essential functions:

  • set_reward_amount: Sets the amount of rewards to distribute over the reward period (owner only). Parameter: amount
  • set_reward_duration: Sets the length of the reward period, i.e. the time over which the reward amount is distributed (owner only). Parameter: duration
  • stake: Lets the caller stake (deposit) staking tokens in the contract. Parameter: amount
  • withdraw: Lets the caller withdraw previously staked tokens. Parameter: amount
  • get_rewards: Computes the rewards currently owed to an account. Parameter: account
  • claim_rewards: Transfers the caller's accumulated rewards to the caller.
#![allow(unused)]
fn main() {
use starknet::ContractAddress;
 
/// External interface of the staking contract.
#[starknet::interface]
pub trait IStakingContract<TContractState> {
    /// Owner-only: starts (or tops up) a reward period that distributes `amount`
    /// reward tokens over the configured duration.
    fn set_reward_amount(ref self: TContractState, amount: u256);
    /// Owner-only: sets the length of the reward period, in block-timestamp units.
    fn set_reward_duration(ref self: TContractState, duration: u256);
    /// Stakes `amount` staking tokens on behalf of the caller.
    fn stake(ref self: TContractState, amount: u256);
    /// Withdraws `amount` staked tokens back to the caller.
    fn withdraw(ref self: TContractState, amount: u256);
    /// Returns the rewards recorded for `account` that have not been claimed yet.
    fn get_rewards(self: @TContractState, account: ContractAddress) -> u256;
    /// Transfers the caller's accumulated rewards to the caller.
    fn claim_rewards(ref self: TContractState);
}
 
/// Panic messages used by the contract's assertions.
/// Each message is a Cairo short string stored in a single `felt252`, so it must
/// stay within 31 characters.
mod Errors {
    /// `set_reward_amount` was called with a zero amount.
    pub const NULL_REWARDS: felt252 = 'Reward amount must be > 0';
    /// `set_reward_amount`: the contract's reward-token balance is lower than
    /// `rate * duration`.
    /// NOTE(review): the text reads inverted relative to that check (the rewards
    /// must NOT exceed the balance) — confirm the intended wording before relying
    /// on it; any replacement must also fit in 31 characters.
    pub const NOT_ENOUGH_REWARDS: felt252 = 'Reward amount must be > balance';
    /// `stake` / `withdraw` were called with a zero amount.
    pub const NULL_AMOUNT: felt252 = 'Amount must be > 0';
    /// The reward duration is zero (either passed in or not configured yet).
    pub const NULL_DURATION: felt252 = 'Duration must be > 0';
    /// `set_reward_duration` was called before the current period's `finish_at`.
    pub const UNFINISHED_DURATION: felt252 = 'Reward duration not finished';
    /// An owner-only entry point was called by another address.
    pub const NOT_OWNER: felt252 = 'Caller is not the owner';
    /// `withdraw`: the balance check against the requested amount failed.
    pub const NOT_ENOUGH_BALANCE: felt252 = 'Balance too low';
}
 
#[starknet::contract]
pub mod StakingContract {
    use core::starknet::event::EventEmitter;
    use core::num::traits::Zero;
    use starknet::{ContractAddress, get_caller_address, get_block_timestamp, get_contract_address};
    use openzeppelin_token::erc20::interface::{IERC20Dispatcher, IERC20DispatcherTrait};
    use starknet::storage::{
        Map, StorageMapReadAccess, StorageMapWriteAccess, StoragePointerReadAccess,
        StoragePointerWriteAccess,
    };
 
    /// Persistent state of the staking contract.
    #[storage]
    struct Storage {
        /// Dispatcher for the ERC20 token that users stake.
        pub staking_token: IERC20Dispatcher,
        /// Dispatcher for the ERC20 token paid out as rewards.
        pub reward_token: IERC20Dispatcher,
        /// Address allowed to call the owner-only entry points (set in the constructor).
        pub owner: ContractAddress,
        /// Reward tokens distributed per unit of block timestamp in the current period.
        pub reward_rate: u256,
        /// Length of a reward period, in block-timestamp units.
        pub duration: u256,
        /// Running accumulator of rewards earned per staked token.
        pub current_reward_per_staked_token: u256,
        /// Block timestamp at which the current reward period ends.
        pub finish_at: u256,
        /// Timestamp up to which the accumulator has been brought up to date.
        pub last_updated_at: u256,
        /// Per-user snapshot of the accumulator taken at the user's last settlement.
        pub last_user_reward_per_staked_token: Map::<ContractAddress, u256>,
        /// Per-user rewards that have been settled but not yet claimed.
        pub unclaimed_rewards: Map::<ContractAddress, u256>,
        /// Running total maintained by `distribute_user_rewards`; reset to 0 by
        /// `set_reward_amount`.
        pub total_distributed_rewards: u256,
        /// Total amount of staking tokens currently staked.
        pub total_supply: u256,
        /// Per-user staked amount.
        pub balance_of: Map::<ContractAddress, u256>,
    }
 
    /// Events emitted by the staking contract.
    #[event]
    #[derive(Copy, Drop, Debug, PartialEq, starknet::Event)]
    pub enum Event {
        // Emitted by `stake`.
        Deposit: Deposit,
        // Emitted by `withdraw`.
        Withdrawal: Withdrawal,
        // Emitted by `send_rewards_finished_event` once the accumulator has been
        // updated up to `finish_at`.
        RewardsFinished: RewardsFinished,
    }
 
    /// Payload of the event emitted when a user stakes tokens.
    #[derive(Copy, Drop, Debug, PartialEq, starknet::Event)]
    pub struct Deposit {
        /// Address that staked.
        pub user: ContractAddress,
        /// Amount of staking tokens deposited.
        pub amount: u256,
    }
 
    /// Payload of the event emitted when a user withdraws staked tokens.
    #[derive(Copy, Drop, Debug, PartialEq, starknet::Event)]
    pub struct Withdrawal {
        /// Address that withdrew.
        pub user: ContractAddress,
        /// Amount of staking tokens withdrawn.
        pub amount: u256,
    }
 
    /// Payload of the event emitted when the reward accounting has reached the
    /// end of the current period.
    #[derive(Copy, Drop, Debug, PartialEq, starknet::Event)]
    pub struct RewardsFinished {
        /// Short-string status message ('Rewards all distributed' or
        /// 'Rewards not active yet').
        pub msg: felt252,
    }
 
    /// Initialises the contract with the staking and reward token addresses and
    /// records the caller of the constructor as the owner.
    ///
    /// NOTE(review): when the contract is deployed through a deployer contract,
    /// the constructor's caller is presumably that deployer rather than the
    /// account that requested the deployment — confirm the deployment flow.
    #[constructor]
    fn constructor(
        ref self: ContractState,
        staking_token_address: ContractAddress,
        reward_token_address: ContractAddress,
    ) {
        let staking_token = IERC20Dispatcher { contract_address: staking_token_address };
        let reward_token = IERC20Dispatcher { contract_address: reward_token_address };
        let caller = get_caller_address();

        self.staking_token.write(staking_token);
        self.reward_token.write(reward_token);
        self.owner.write(caller);
    }
 
    #[abi(embed_v0)]
    impl StakingContract of super::IStakingContract<ContractState> {
        /// Owner-only: sets the length of the reward period.
        ///
        /// Panics with `NOT_OWNER` for any other caller, with `NULL_DURATION`
        /// when `duration` is zero, and with `UNFINISHED_DURATION` unless the
        /// current period's `finish_at` is strictly before the block timestamp.
        fn set_reward_duration(ref self: ContractState, duration: u256) {
            self.only_owner();
            assert(duration > 0, super::Errors::NULL_DURATION);

            // The duration may only change once the running period is over.
            let period_end = self.finish_at.read();
            let now: u256 = get_block_timestamp().into();
            assert(period_end < now, super::Errors::UNFINISHED_DURATION);

            self.duration.write(duration);
        }
 
        /// Owner-only: starts a new reward period funded with `amount`, or tops
        /// up the running one.
        ///
        /// The new rate is `(leftover + amount) / duration`, where `leftover` is
        /// what the running period would still have paid out (zero once it has
        /// finished). The period is then restarted from the current block
        /// timestamp and `total_distributed_rewards` is reset.
        ///
        /// Panics with `NOT_OWNER`, `NULL_REWARDS`, `NULL_DURATION`, or with
        /// `NOT_ENOUGH_REWARDS` when the contract's reward-token balance does not
        /// cover `rate * duration`.
        fn set_reward_amount(ref self: ContractState, amount: u256) {
            self.only_owner();
            // Bring the global accumulator up to date before the rate changes;
            // the zero address means "no user to settle".
            self.update_rewards(Zero::zero());

            assert(amount > 0, super::Errors::NULL_REWARDS);
            let period = self.duration.read();
            assert(period > 0, super::Errors::NULL_DURATION);

            let now: u256 = get_block_timestamp().into();
            let period_end = self.finish_at.read();

            // Rewards of the running period that have not been streamed yet.
            let leftover: u256 = if period_end < now {
                0
            } else {
                self.reward_rate.read() * (period_end - now)
            };
            let rate = (leftover + amount) / period;

            // The contract must already hold enough reward tokens for the period.
            let funded = self.reward_token.read().balance_of(get_contract_address());
            assert(funded >= rate * period, super::Errors::NOT_ENOUGH_REWARDS);

            self.reward_rate.write(rate);
            self.finish_at.write(now + period);
            self.last_updated_at.write(now);
            self.total_distributed_rewards.write(0);
        }
 
        /// Stakes `amount` staking tokens on behalf of the caller and emits
        /// `Deposit`.
        ///
        /// The tokens are pulled with `transfer_from`, so the caller must have
        /// approved this contract beforehand. Panics with `NULL_AMOUNT` when
        /// `amount` is zero.
        fn stake(ref self: ContractState, amount: u256) {
            assert(amount > 0, super::Errors::NULL_AMOUNT);

            let staker = get_caller_address();
            // Settle rewards earned on the old balance before it changes.
            self.update_rewards(staker);

            let staked_after = self.balance_of.read(staker) + amount;
            let supply_after = self.total_supply.read() + amount;
            self.balance_of.write(staker, staked_after);
            self.total_supply.write(supply_after);

            let this_contract = get_contract_address();
            self.staking_token.read().transfer_from(staker, this_contract, amount);

            self.emit(Deposit { user: staker, amount });
        }
 
        /// Withdraws `amount` of the caller's staked tokens back to the caller
        /// and emits `Withdrawal`.
        ///
        /// Panics with `NULL_AMOUNT` when `amount` is zero and with
        /// `NOT_ENOUGH_BALANCE` when the caller has staked less than `amount`.
        fn withdraw(ref self: ContractState, amount: u256) {
            assert(amount > 0, super::Errors::NULL_AMOUNT);

            let user = get_caller_address();

            // Check the caller's *staked* balance tracked by this contract, not
            // their wallet balance of the staking token: someone who staked their
            // whole wallet holds zero tokens yet must still be able to withdraw,
            // and a large wallet balance must not let the subtraction below
            // underflow with an unrelated error.
            let staked = self.balance_of.read(user);
            assert(staked >= amount, super::Errors::NOT_ENOUGH_BALANCE);

            // Settle rewards earned on the old balance before it changes.
            self.update_rewards(user);

            self.balance_of.write(user, staked - amount);
            self.total_supply.write(self.total_supply.read() - amount);
            self.staking_token.read().transfer(user, amount);

            self.emit(Withdrawal { user, amount });
        }
 
        /// Returns the rewards owed to `account`: the amount already settled
        /// plus what `compute_new_rewards` reports on top of it.
        ///
        /// This is a view; it works from the stored accumulator and does not
        /// bring it up to the current block timestamp first.
        fn get_rewards(self: @ContractState, account: ContractAddress) -> u256 {
            let settled = self.unclaimed_rewards.read(account);
            let pending = self.compute_new_rewards(account);
            settled + pending
        }
 
        /// Settles the caller's rewards and transfers them to the caller.
        /// Does nothing further when no rewards are owed.
        fn claim_rewards(ref self: ContractState) {
            let claimant = get_caller_address();
            self.update_rewards(claimant);

            let owed = self.unclaimed_rewards.read(claimant);
            if owed == 0 {
                return;
            }

            // Zero the balance before the external token call.
            self.unclaimed_rewards.write(claimant, 0);
            self.reward_token.read().transfer(claimant, owed);
        }
    }
 
    #[generate_trait]
    impl PrivateFunctions of PrivateFunctionsTrait {
        /// Brings the global reward accumulator up to date and, for a non-zero
        /// `account`, settles that account's rewards.
        ///
        /// Passing the zero address updates only the global state.
        fn update_rewards(ref self: ContractState, account: ContractAddress) {
            // Both values are derived from the state *before* either is written:
            // the accumulator still needs the previous `last_updated_at`.
            let accumulator = self.compute_current_reward_per_staked_token();
            let updated_until = self.last_time_applicable();
            self.current_reward_per_staked_token.write(accumulator);
            self.last_updated_at.write(updated_until);

            if account.is_zero() {
                return;
            }

            self.distribute_user_rewards(account);
            // Snapshot the accumulator so the same rewards are not settled twice.
            self.last_user_reward_per_staked_token.write(account, accumulator);

            self.send_rewards_finished_event();
        }
 
        /// Moves the rewards `account` has newly accrued into its unclaimed
        /// balance and adds them to `total_distributed_rewards`.
        fn distribute_user_rewards(ref self: ContractState, account: ContractAddress) {
            let newly_accrued = self.compute_new_rewards(account);
            let already_settled = self.unclaimed_rewards.read(account);
            self.unclaimed_rewards.write(account, already_settled + newly_accrued);

            // Count only the newly accrued part. The already-settled amount was
            // added to the total when it was first settled; adding the full
            // unclaimed balance again would double-count it on every settlement
            // that happens without an intervening claim.
            self
                .total_distributed_rewards
                .write(self.total_distributed_rewards.read() + newly_accrued);
        }
 
        /// Emits `RewardsFinished` once the accumulator has been updated exactly
        /// up to `finish_at`.
        ///
        /// The message is 'Rewards all distributed' when the period's total
        /// (`reward_rate * duration`) is non-zero and equals
        /// `total_distributed_rewards`, and 'Rewards not active yet' otherwise.
        fn send_rewards_finished_event(ref self: ContractState) {
            if self.last_updated_at.read() != self.finish_at.read() {
                return;
            }

            let period_total = self.reward_rate.read() * self.duration.read();
            let fully_distributed = period_total != 0
                && self.total_distributed_rewards.read() == period_total;

            let msg: felt252 = if fully_distributed {
                'Rewards all distributed'
            } else {
                'Rewards not active yet'
            };
            self.emit(RewardsFinished { msg });
        }
 
        /// Returns the reward-per-staked-token accumulator advanced to
        /// `last_time_applicable()`.
        ///
        /// With nothing staked the stored value is returned unchanged. The
        /// increment uses integer division, so it truncates towards zero.
        fn compute_current_reward_per_staked_token(self: @ContractState) -> u256 {
            let stored = self.current_reward_per_staked_token.read();
            let staked_total = self.total_supply.read();
            if staked_total == 0 {
                return stored;
            }

            let elapsed = self.last_time_applicable() - self.last_updated_at.read();
            let increment = self.reward_rate.read() * elapsed / staked_total;
            stored + increment
        }
 
        /// Returns the rewards `account` has accrued since its last settlement:
        /// its staked balance times the growth of the stored accumulator since
        /// the account's snapshot.
        fn compute_new_rewards(self: @ContractState, account: ContractAddress) -> u256 {
            let staked = self.balance_of.read(account);
            let accumulator_now = self.current_reward_per_staked_token.read();
            let accumulator_at_snapshot = self.last_user_reward_per_staked_token.read(account);
            staked * (accumulator_now - accumulator_at_snapshot)
        }
 
        /// Returns the latest timestamp at which rewards still accrue: the
        /// current block timestamp, capped at `finish_at`.
        #[inline(always)]
        fn last_time_applicable(self: @ContractState) -> u256 {
            let period_end = self.finish_at.read();
            let now: u256 = get_block_timestamp().into();
            Self::min(period_end, now)
        }
 
        /// Returns the smaller of `x` and `y` (`x` when they are equal).
        #[inline(always)]
        fn min(x: u256, y: u256) -> u256 {
            if y < x {
                return y;
            }
            x
        }
 
        /// Panics with `NOT_OWNER` unless the caller is the stored owner.
        fn only_owner(self: @ContractState) {
            let owner = self.owner.read();
            assert(get_caller_address() == owner, super::Errors::NOT_OWNER);
        }
    }
}
}