diff --git a/Makefile b/Makefile index 60994e2..4708723 100644 --- a/Makefile +++ b/Makefile @@ -2,3 +2,8 @@ lint: @ruff check . --fix @ruff format . .PHONY: lint + +test: + python -m coverage run -m pytest tests/ -v && \ + python -m coverage report +.PHONY: test \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index f2c6551..4862524 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,12 +31,14 @@ dependencies = [ "snet-contracts==1.0.1", "lighthouseweb3~=0.1.4", "py-multihash~=3.0", + "pydantic~=2.11", "pydantic-settings~=2.13" ] [tool.poetry.group.dev.dependencies] ruff = "^0.11" pytest = "^8.3" +coverage = "^7.13" [tool.ruff] line-length = 100 diff --git a/requirements.txt b/requirements.txt index 4bc7135..5731954 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,4 +9,5 @@ ipfshttpclient==0.4.13.2 snet-contracts==1.0.1 lighthouseweb3~=0.1.4 py-multihash~=3.0 +pydantic~=2.11 pydantic-settings~=2.13 \ No newline at end of file diff --git a/snet/sdk/__init__.py b/snet/sdk/__init__.py index e574fa0..641ea08 100644 --- a/snet/sdk/__init__.py +++ b/snet/sdk/__init__.py @@ -7,7 +7,8 @@ import google.protobuf.internal.api_implementation from google.protobuf import symbol_database as _symbol_database -from snet.sdk.storage_provider.service_metadata import MPEServiceMetadata +from snet.sdk.registry.registry_contract import RegistryContract +from snet.sdk.registry.service_metadata import MPEServiceMetadata with warnings.catch_warnings(): # Suppress the eth-typing package`s warnings related to some new networks @@ -18,9 +19,6 @@ UserWarning, ) - import web3 - -from snet.contracts import get_contract_object from snet.sdk.account import Account from snet.sdk.config import config from snet.sdk.client_lib_generator import ClientLibGenerator @@ -34,12 +32,11 @@ PaymentStrategy, ) from snet.sdk.service_client import ServiceClient -from snet.sdk.storage_provider.storage_provider import StorageProvider +from snet.sdk.registry.storage_provider import StorageProvider from snet.sdk.custom_typing import ModuleName, ServiceStub from snet.sdk.utils.utils import ( - bytes32_to_str, find_file_by_keyword, - type_converter, + get_we3_object, ) google.protobuf.internal.api_implementation.Type = lambda: "python" @@ -55,29 +52,13 @@ class PaymentStrategyType(Enum): class SnetSDK: - """Base Snet SDK""" - def __init__(self): - self.web3 = web3.Web3(web3.HTTPProvider(config.ETH_RPC_ENDPOINT)) - - mpe_contract_address = config.MPE_CONTRACT_ADDRESS - if not mpe_contract_address: - self.mpe_contract = MPEContract(self.web3) - else: - self.mpe_contract = MPEContract(self.web3, mpe_contract_address) - - registry_contract_address = config.REGISTRY_CONTRACT_ADDRESS - if registry_contract_address is None: - self.registry_contract = get_contract_object(self.web3, "Registry") - else: - self.registry_contract = get_contract_object( - self.web3, "Registry", registry_contract_address - ) - + self.w3 = get_we3_object() + self.mpe_contract = MPEContract() + self.registry_contract = RegistryContract() self.metadata_provider = StorageProvider(self.registry_contract) - - self.account = Account(self.web3, self.mpe_contract) - self.payment_channel_provider = PaymentChannelProvider(self.web3, self.mpe_contract) + self.payment_channel_provider = PaymentChannelProvider(self.mpe_contract) + self.account = Account() def create_service_client( self, @@ -99,7 +80,7 @@ def create_service_client( if force_update: lib_generator.generate_client_library() else: - path_to_pb_files = lib_generator.protodir + path_to_pb_files = lib_generator.proto_dir pb_2_file_name = find_file_by_keyword( path_to_pb_files, keyword="pb2.py", exclude=["training"] ) @@ -134,16 +115,16 @@ def create_service_client( options, self.mpe_contract, self.account, - self.web3, + self.w3, pb2_module, self.payment_channel_provider, - lib_generator.protodir, + lib_generator.proto_dir, lib_generator.training_added(), ) return _service_client def get_service_stub(self, lib_generator: ClientLibGenerator) -> list[ServiceStub]: - path_to_pb_files = str(lib_generator.protodir) + path_to_pb_files = str(lib_generator.proto_dir) module_name = self.get_module_by_keyword("pb2_grpc.py", lib_generator) sys.path.append(path_to_pb_files) try: @@ -159,7 +140,7 @@ def get_service_stub(self, lib_generator: ClientLibGenerator) -> list[ServiceStu raise Exception(f"Error importing module: {e}") def get_module_by_keyword(self, keyword: str, lib_generator: ClientLibGenerator) -> ModuleName: - path_to_pb_files = lib_generator.protodir + path_to_pb_files = lib_generator.proto_dir file_name = find_file_by_keyword(path_to_pb_files, keyword, exclude=["training"]) module_name = os.path.splitext(file_name)[0] return ModuleName(module_name) @@ -176,7 +157,8 @@ def _get_group_by_group_name( for group in service_metadata["groups"]: if group["group_name"] == group_name: return group - return {} + # TODO: configure exceptions + raise Exception() def _get_service_group_details( self, service_metadata: MPEServiceMetadata, group_name: str @@ -190,17 +172,7 @@ def _get_service_group_details( return self._get_group_by_group_name(service_metadata, group_name) def get_organization_list(self) -> list: - org_list = self.registry_contract.functions.listOrganizations().call() - organization_list = [] - for idx, org_id in enumerate(org_list): - organization_list.append(bytes32_to_str(org_id)) - return organization_list + return self.registry_contract.list_orgs() def get_services_list(self, org_id: str) -> list: - found, org_service_list = self.registry_contract.functions.listServicesForOrganization( - type_converter("bytes32")(org_id) - ).call() - if not found: - raise Exception(f"Organization with id={org_id} doesn't exist!") - org_service_list = list(map(bytes32_to_str, org_service_list)) - return org_service_list + return self.registry_contract.list_service_for_org(org_id) diff --git a/snet/sdk/account.py b/snet/sdk/account.py index 31b404d..d8e6d12 100644 --- a/snet/sdk/account.py +++ b/snet/sdk/account.py @@ -1,12 +1,10 @@ import json -import web3 from snet.contracts import get_contract_object from snet.sdk.config import config -from snet.sdk.mpe.mpe_contract import MPEContract -from snet.sdk.utils.utils import get_address_from_private, normalize_private_key +from snet.sdk.utils.utils import get_address_from_private, normalize_private_key, get_we3_object DEFAULT_GAS = 300000 TRANSACTION_TIMEOUT = 500 @@ -28,17 +26,13 @@ def __str__(self): class Account: - def __init__(self, w3: web3.Web3, mpe_contract: MPEContract): - self.web3 = w3 - self.mpe_contract = mpe_contract + def __init__(self): + self.w3 = get_we3_object() + self.mpe_address = config.MPE_CONTRACT_ADDRESS - token_contract_address = config.TOKEN_CONTRACT_ADDRESS - if not token_contract_address: - self.token_contract = get_contract_object(self.web3, "FetchToken") - else: - self.token_contract = get_contract_object( - self.web3, "FetchToken", token_contract_address - ) + self.token_contract = get_contract_object( + self.w3, "FetchToken", config.TOKEN_CONTRACT_ADDRESS + ) if config.PRIVATE_KEY: self.private_key = normalize_private_key(config.PRIVATE_KEY) @@ -52,19 +46,19 @@ def __init__(self, w3: web3.Web3, mpe_contract: MPEContract): self.nonce = 0 def _get_nonce(self): - nonce = self.web3.eth.get_transaction_count(self.address) + nonce = self.w3.eth.get_transaction_count(self.address) if self.nonce >= nonce: nonce = self.nonce + 1 self.nonce = nonce return nonce def _get_gas_price(self): - gas_price = self.web3.eth.gas_price + gas_price = self.w3.eth.gas_price if gas_price <= 15000000000: gas_price += gas_price * 1 / 3 - elif gas_price > 15000000000 and gas_price <= 50000000000: + elif 15000000000 < gas_price <= 50000000000: gas_price += gas_price * 1 / 5 - elif gas_price > 50000000000 and gas_price <= 150000000000: + elif 50000000000 < gas_price <= 150000000000: gas_price += 7000000000 elif gas_price > 150000000000: gas_price += gas_price * 1 / 10 @@ -73,20 +67,18 @@ def _get_gas_price(self): def _send_signed_transaction(self, contract_fn, *args): transaction = contract_fn(*args).build_transaction( { - "chainId": int(self.web3.net.version), + "chainId": int(self.w3.net.version), "gas": DEFAULT_GAS, "gasPrice": self._get_gas_price(), "nonce": self._get_nonce(), } ) - signed_txn = self.web3.eth.account.sign_transaction( - transaction, private_key=self.private_key - ) - return self.web3.to_hex(self.web3.eth.send_raw_transaction(signed_txn.raw_transaction)) + signed_txn = self.w3.eth.account.sign_transaction(transaction, private_key=self.private_key) + return self.w3.to_hex(self.w3.eth.send_raw_transaction(signed_txn.raw_transaction)) def send_transaction(self, contract_fn, *args): txn_hash = self._send_signed_transaction(contract_fn, *args) - return self.web3.eth.wait_for_transaction_receipt(txn_hash, TRANSACTION_TIMEOUT) + return self.w3.eth.wait_for_transaction_receipt(txn_hash, TRANSACTION_TIMEOUT) def _parse_receipt(self, receipt, event, encoder=json.JSONEncoder): if receipt.status == 0: @@ -94,23 +86,12 @@ def _parse_receipt(self, receipt, event, encoder=json.JSONEncoder): else: return json.dumps(dict(event().processReceipt(receipt)[0]["args"]), cls=encoder) - def escrow_balance(self): - return self.mpe_contract.balance(self.address) - - def deposit_to_escrow_account(self, amount_in_cogs): - already_approved = self.allowance() - if amount_in_cogs > already_approved: - self.approve_transfer(amount_in_cogs) - return self.mpe_contract.deposit(self, amount_in_cogs) - def approve_transfer(self, amount_in_cogs): return self.send_transaction( self.token_contract.functions.approve, - self.mpe_contract.contract.address, + self.mpe_address, amount_in_cogs, ) def allowance(self): - return self.token_contract.functions.allowance( - self.address, self.mpe_contract.contract.address - ).call() + return self.token_contract.functions.allowance(self.address, self.mpe_address).call() diff --git a/snet/sdk/client_lib_generator.py b/snet/sdk/client_lib_generator.py index eb485c1..bd58ec4 100644 --- a/snet/sdk/client_lib_generator.py +++ b/snet/sdk/client_lib_generator.py @@ -1,7 +1,7 @@ import os from pathlib import Path -from snet.sdk.storage_provider.storage_provider import StorageProvider +from snet.sdk.registry.storage_provider import StorageProvider from snet.sdk.utils.utils import compile_proto @@ -11,21 +11,21 @@ def __init__( metadata_provider: StorageProvider, org_id: str, service_id: str, - protodir: Path | None = None, + proto_dir: Path | None = None, ): self._metadata_provider: StorageProvider = metadata_provider self.org_id: str = org_id self.service_id: str = service_id self.language: str = "python" - self.protodir: Path = protodir if protodir else Path.home().joinpath(".snet") + self.proto_dir: Path = proto_dir if proto_dir else Path.home().joinpath(".snet") self.generate_directories_by_params() def generate_client_library(self) -> None: try: self.receive_proto_files() compilation_result = compile_proto( - entry_path=self.protodir, - codegen_dir=self.protodir, + entry_path=self.proto_dir, + codegen_dir=self.proto_dir, target_language=self.language, add_training=self.training_added(), ) @@ -33,19 +33,19 @@ def generate_client_library(self) -> None: print( f'client libraries for service with id "{self.service_id}" ' f'in org with id "{self.org_id}" ' - f"generated at {self.protodir}" + f"generated at {self.proto_dir}" ) except Exception as e: print(str(e)) def generate_directories_by_params(self) -> None: - if not self.protodir.is_absolute(): - self.protodir = Path.cwd().joinpath(self.protodir) + if not self.proto_dir.is_absolute(): + self.proto_dir = Path.cwd().joinpath(self.proto_dir) self.create_service_client_libraries_path() def create_service_client_libraries_path(self) -> None: - self.protodir = self.protodir.joinpath(self.org_id, self.service_id, self.language) - self.protodir.mkdir(parents=True, exist_ok=True) + self.proto_dir = self.proto_dir.joinpath(self.org_id, self.service_id, self.language) + self.proto_dir.mkdir(parents=True, exist_ok=True) def receive_proto_files(self) -> None: metadata = self._metadata_provider.fetch_service_metadata( @@ -54,17 +54,17 @@ def receive_proto_files(self) -> None: service_api_source = metadata.get("service_api_source") or metadata.get("model_ipfs_hash") # Receive proto files - if self.protodir.exists(): - self._metadata_provider.fetch_and_extract_proto(service_api_source, self.protodir) + if self.proto_dir.exists(): + self._metadata_provider.fetch_and_extract_proto(service_api_source, self.proto_dir) else: raise Exception("Directory for storing proto files is not found") def training_added(self) -> bool: - files = os.listdir(self.protodir) + files = os.listdir(self.proto_dir) for file in files: if ".proto" not in file: continue - with open(self.protodir.joinpath(file), "r") as f: + with open(self.proto_dir.joinpath(file), "r") as f: proto_text = f.read() if 'import "training.proto";' in proto_text: return True diff --git a/snet/sdk/mpe/mpe_contract.py b/snet/sdk/mpe/mpe_contract.py index 3823da1..2a847f6 100644 --- a/snet/sdk/mpe/mpe_contract.py +++ b/snet/sdk/mpe/mpe_contract.py @@ -1,21 +1,31 @@ +from typing import Optional + from snet.contracts import get_contract_object +from snet.sdk.account import Account +from snet.sdk.config import config +from snet.sdk.utils.utils import get_we3_object + class MPEContract: - def __init__(self, w3, address=None): - self.web3 = w3 - if address is None: - self.contract = get_contract_object(self.web3, "MultiPartyEscrow") - else: - self.contract = get_contract_object(self.web3, "MultiPartyEscrow", address) - - def balance(self, address): + def __init__(self): + self.w3 = get_we3_object() + self.contract = get_contract_object( + self.w3, "MultiPartyEscrow", config.MPE_CONTRACT_ADDRESS + ) + + def balance(self, account: Account, address: Optional[str] = None): + if not address: + address = account.address return self.contract.functions.balances(address).call() def deposit(self, account, amount_in_cogs): + already_approved = account.allowance() + if amount_in_cogs > already_approved: + account.approve_transfer(amount_in_cogs) return account.send_transaction(self.contract.functions.deposit, amount_in_cogs) - def open_channel(self, account, payment_address, group_id, amount, expiration): + def open_channel(self, account: Account, payment_address, group_id, amount, expiration): return account.send_transaction( self.contract.functions.openChannel, account.signer_address, @@ -25,7 +35,9 @@ def open_channel(self, account, payment_address, group_id, amount, expiration): expiration, ) - def deposit_and_open_channel(self, account, payment_address, group_id, amount, expiration): + def deposit_and_open_channel( + self, account: Account, payment_address, group_id, amount, expiration + ): already_approved_amount = account.allowance() if amount > already_approved_amount: account.approve_transfer(amount) @@ -38,16 +50,16 @@ def deposit_and_open_channel(self, account, payment_address, group_id, amount, e expiration, ) - def channel_add_funds(self, account, channel_id, amount): + def channel_add_funds(self, account: Account, channel_id, amount): self._fund_escrow_account(account, amount) return account.send_transaction(self.contract.functions.channelAddFunds, channel_id, amount) - def channel_extend(self, account, channel_id, expiration): + def channel_extend(self, account: Account, channel_id, expiration): return account.send_transaction( self.contract.functions.channelExtend, channel_id, expiration ) - def channel_extend_and_add_funds(self, account, channel_id, expiration, amount): + def channel_extend_and_add_funds(self, account: Account, channel_id, expiration, amount): self._fund_escrow_account(account, amount) return account.send_transaction( self.contract.functions.channelExtendAndAddFunds, @@ -56,7 +68,7 @@ def channel_extend_and_add_funds(self, account, channel_id, expiration, amount): amount, ) - def _fund_escrow_account(self, account, amount): + def _fund_escrow_account(self, account: Account, amount): current_escrow_balance = self.balance(account.address) if amount > current_escrow_balance: - account.deposit_to_escrow_account(amount - current_escrow_balance) + self.deposit(amount - current_escrow_balance) diff --git a/snet/sdk/mpe/payment_channel.py b/snet/sdk/mpe/payment_channel.py index f413482..616c70e 100644 --- a/snet/sdk/mpe/payment_channel.py +++ b/snet/sdk/mpe/payment_channel.py @@ -2,20 +2,19 @@ import importlib from eth_account.messages import defunct_hash_message -from snet.sdk.utils.utils import RESOURCES_PATH, add_to_path +from snet.sdk.utils.utils import RESOURCES_PATH, add_to_path, get_we3_object class PaymentChannel: def __init__( self, channel_id, - w3, account, payment_channel_state_service_client, mpe_contract, ): self.channel_id = channel_id - self.web3 = w3 + self.web3 = get_we3_object() self.account = account self.mpe_contract = mpe_contract self.payment_channel_state_service_client = payment_channel_state_service_client diff --git a/snet/sdk/mpe/payment_channel_provider.py b/snet/sdk/mpe/payment_channel_provider.py index 40b05c7..0a12fff 100644 --- a/snet/sdk/mpe/payment_channel_provider.py +++ b/snet/sdk/mpe/payment_channel_provider.py @@ -4,8 +4,10 @@ from web3.types import LogReceipt +from snet.sdk.utils.utils import get_we3_object from snet.sdk.mpe.payment_channel import PaymentChannel from snet.contracts import get_contract_deployment_block +from snet.sdk.mpe.mpe_contract import MPEContract BLOCKS_PER_BATCH = 50000 @@ -13,8 +15,8 @@ class PaymentChannelProvider(object): - def __init__(self, w3, mpe_contract): - self.web3 = w3 + def __init__(self, mpe_contract: MPEContract): + self.web3 = get_we3_object() self.mpe_contract = mpe_contract self.event_topics = [ @@ -138,7 +140,6 @@ def get_past_open_channels( map( lambda channel: PaymentChannel( channel["channel_id"], - self.web3, account, payment_channel_state_service_client, self.mpe_contract, diff --git a/snet/sdk/payment_strategies/paidcall_payment_strategy.py b/snet/sdk/payment_strategies/paidcall_payment_strategy.py index 534a21e..a7127c4 100644 --- a/snet/sdk/payment_strategies/paidcall_payment_strategy.py +++ b/snet/sdk/payment_strategies/paidcall_payment_strategy.py @@ -36,13 +36,12 @@ def get_payment_metadata(self, service_client): return metadata def select_channel(self, service_client): - account = service_client.account service_client.load_open_channels() service_client.update_channel_states() payment_channels = service_client.payment_channels # picking the first pricing strategy as default for now service_call_price = self.get_price(service_client) - mpe_balance = account.escrow_balance() + mpe_balance = service_client.get_mpe_balance() default_expiration = service_client.default_channel_expiration() if len(payment_channels) < 1: diff --git a/snet/sdk/payment_strategies/prepaid_payment_strategy.py b/snet/sdk/payment_strategies/prepaid_payment_strategy.py index 5a61087..c2e3ce0 100644 --- a/snet/sdk/payment_strategies/prepaid_payment_strategy.py +++ b/snet/sdk/payment_strategies/prepaid_payment_strategy.py @@ -40,13 +40,12 @@ def get_concurrency_token_and_channel(self, service_client): return token, channel def select_channel(self, service_client): - account = service_client.account service_client.load_open_channels() service_client.update_channel_states() payment_channels = service_client.payment_channels service_call_price = self.get_price(service_client) extend_channel_fund = service_call_price * self.call_allowance - mpe_balance = account.escrow_balance() + mpe_balance = service_client.get_mpe_balance() default_expiration = service_client.default_channel_expiration() if len(payment_channels) < 1: diff --git a/snet/sdk/storage_provider/__init__.py b/snet/sdk/registry/__init__.py similarity index 100% rename from snet/sdk/storage_provider/__init__.py rename to snet/sdk/registry/__init__.py diff --git a/snet/sdk/registry/registry_contract.py b/snet/sdk/registry/registry_contract.py new file mode 100644 index 0000000..50bd0be --- /dev/null +++ b/snet/sdk/registry/registry_contract.py @@ -0,0 +1,97 @@ +from typing import Union + +from snet.contracts import get_contract_object + +from snet.sdk.account import Account +from snet.sdk.config import config +from snet.sdk.types import RawOrgData, OrgData, ServiceData, RawServiceData +from snet.sdk.utils.utils import ( + type_converter, + bytes32_to_str, + get_we3_object, + convert_raw_service_data, + convert_raw_org_data, +) + + +class RegistryContract: + def __init__(self): + self.w3 = get_we3_object() + self.contract = get_contract_object(self.w3, "Registry", config.REGISTRY_CONTRACT_ADDRESS) + + # READ METHODS + + def get_org(self, org_id: str) -> OrgData: + found, found_org_id, org_metadata_uri, owner, members, service_ids = ( + self.contract.functions.getOrganizationById(type_converter("bytes32")(org_id)).call() + ) + if not found: + # TODO: configure exceptions + raise Exception() + + return convert_raw_org_data( + RawOrgData( + org_id=found_org_id, + metadata_uri=org_metadata_uri, + owner=owner, + members=members, + services=service_ids, + ) + ) + + def get_service(self, org_id: str, service_id: str) -> ServiceData: + found, found_service_id, service_metadata_uri = ( + self.contract.functions.getServiceRegistrationById( + type_converter("bytes32")(org_id), type_converter("bytes32")(service_id) + ).call() + ) + if not found: + # TODO: configure exceptions + raise Exception() + + return convert_raw_service_data( + RawServiceData(service_id=found_service_id, metadata_uri=service_metadata_uri), + org_id=org_id, + ) + + def list_orgs(self) -> list[str]: + org_list = self.contract.functions.listOrganizations().call() + return list(map(bytes32_to_str, org_list)) + + def list_service_for_org(self, org_id: str) -> list[str]: + found, org_service_list = self.contract.functions.listServicesForOrganization( + type_converter("bytes32")(org_id) + ).call() + if not found: + # TODO: configure exceptions + raise Exception() + else: + return list(map(bytes32_to_str, org_service_list)) + + # WRITE METHODS + + def add_org_members( + self, account: Account, org_id: str, members: Union[str, list[str], None] + ): ... + + def update_org_metadata(self, account: Account, org_id: str, metadata_uri: str): ... + + def change_org_owner(self, account: Account, org_id: str, new_owner: str): ... + + def create_org( + self, account: Account, org_id: str, metadata_uri: str, members: Union[str, list[str], None] + ): ... + + def create_service(self, account: Account, org_id: str, service_id: str, metadata_uri: str): ... + + def delete_org(self, account: Account, org_id: str): ... + + def delete_service(self, account: Account, org_id: str, service_id: str): ... + + def remove_org_members( + self, account: Account, org_id: str, members_to_remove: Union[str, list[str]] + ): ... + + def update_service_metadata( + self, account: Account, org_id: str, service_id: str, metadata_uri: str + ): ... diff --git a/snet/sdk/storage_provider/service_metadata.py b/snet/sdk/registry/service_metadata.py similarity index 90% rename from snet/sdk/storage_provider/service_metadata.py rename to snet/sdk/registry/service_metadata.py index 9ab11b1..c734029 100644 --- a/snet/sdk/storage_provider/service_metadata.py +++ b/snet/sdk/registry/service_metadata.py @@ -39,9 +39,13 @@ import re import json import base64 +import secrets from collections import defaultdict from enum import Enum +from typing import Literal, Any, Optional + +from pydantic import BaseModel, Field, model_validator, ValidationInfo from snet.sdk.utils.utils import is_valid_endpoint @@ -63,6 +67,75 @@ def is_single_value(asset_type): return True +def generate_group_id() -> str: + return base64.b64encode(secrets.token_bytes(32)).decode() + + +class Pricing(BaseModel): + price_model: Literal["fixed_price", "method_price"] = Field(default="fixed_price") + price_in_cogs: int = Field(ge=1) + default: bool = Field(default=True) + + +class Group(BaseModel): + group_name: str = Field(min_length=1, default="default_group") + group_id: str = Field(default_factory=generate_group_id, init=False) + free_calls: int = Field(ge=1) + free_call_signer_address: str + daemon_addresses: list[str] + + +class ServiceDescription(BaseModel): ... + + +class Media(BaseModel): ... + + +class Contributor(BaseModel): ... + + +class ServiceMetadata(BaseModel): + version: int = Field(ge=1, default=1) + display_name: str + encoding: Literal["proto", "json"] = Field(default="proto") + service_type: Literal["grpc", "http", "jsonrpc"] + service_api_source: Optional[str] = Field(min_length=1, default=None, init=False) + model_ipfs_hash: Optional[str] = Field(min_length=1, default=None, init=False, deprecated=True) + mpe_address: str + groups: list[Group] + service_description: ServiceDescription + media: list[Media] + contributors: list[Contributor] + tags: list[str] + + @model_validator(mode="before") + @classmethod + def restrict_deprecated_fields(cls, data: Any, info: ValidationInfo) -> Any: + if not isinstance(data, dict): + return data + + is_fetching = info.context and info.context.get("from_storage") is True + + if not is_fetching and data.get("model_ipfs_hash"): + # TODO: configure exceptions + raise ValueError( + "The 'model_ipfs_hash' field is deprecated and cannot be used " + "to create new metadata. Please use 'service_api_source' instead." + ) + + return data + + def generate_final_json(self): + if self.service_api_source is None: + if self.model_ipfs_hash is None: + # TODO: configure exceptions + raise ValueError("The 'service_api_source' field is missing!") + else: + self.service_api_source, self.model_ipfs_hash = self.model_ipfs_hash, None + + return self.model_dump_json(indent=2, exclude_none=True) + + # TODO: we should use some standard solution here class MPEServiceMetadata: def __init__(self): diff --git a/snet/sdk/registry/storage_provider.py b/snet/sdk/registry/storage_provider.py new file mode 100644 index 0000000..f11d698 --- /dev/null +++ b/snet/sdk/registry/storage_provider.py @@ -0,0 +1,76 @@ +from typing import Any + +from lighthouseweb3 import Lighthouse +import json + +from snet.sdk.registry.registry_contract import RegistryContract +from snet.sdk.types import StorageType, FileURI +from snet.sdk.utils.ipfs_utils import ( + get_ipfs_client, + get_from_ipfs_and_checkhash, +) +from snet.sdk.utils.utils import bytesuri_to_hash, safe_extract_proto +from snet.sdk.registry.service_metadata import ( + MPEServiceMetadata, + mpe_service_metadata_from_json, +) +from snet.sdk.config import config + + +class StorageProvider(object): + def __init__(self, registry_contract: RegistryContract): + self._registry_contract = registry_contract + self._ipfs_client = get_ipfs_client() + self.lighthouse_client = Lighthouse(config.LIGHTHOUSE_TOKEN) + + def fetch_org_metadata(self, org_id): + org = self._registry_contract.get_org(org_id) + + org_metadata_json = self._get_from_storage(org.metadata_uri) + org_metadata = json.loads(org_metadata_json) + + return org_metadata + + def fetch_service_metadata(self, org_id: str, service_id: str) -> MPEServiceMetadata: + service = self._registry_contract.get_service(org_id, service_id) + + service_metadata_json = self._get_from_storage(service.metadata_uri) + service_metadata = mpe_service_metadata_from_json(service_metadata_json) + + return service_metadata + + def enhance_service_metadata(self, org_id, service_id): + service_metadata = self.fetch_service_metadata(org_id, service_id) + org_metadata = self.fetch_org_metadata(org_id) + + org_group_map = {} + for group in org_metadata["groups"]: + org_group_map[group["group_name"]] = group + + for group in service_metadata.m["groups"]: + # merge service group with org_group + group["payment"] = org_group_map[group["group_name"]]["payment"] + + return service_metadata + + def fetch_and_extract_proto(self, service_api_source, proto_dir): + try: + tar_uri = bytesuri_to_hash(service_api_source, to_decode=False) + except Exception: + # TODO: change exception based on bytesuri_to_hash function + tar_uri = FileURI(storage_type=StorageType.IPFS, uri_hash=service_api_source) + + spec_tar = self._get_from_storage(tar_uri) + + safe_extract_proto(spec_tar, proto_dir) + + def _get_from_storage(self, uri: FileURI) -> Any: + if uri.storage_type == StorageType.IPFS: + file = get_from_ipfs_and_checkhash(self._ipfs_client, uri.uri_hash) + elif uri.storage_type == StorageType.FILECOIN: + file, _ = self.lighthouse_client.download(uri.uri_hash) + else: + # TODO: configure exceptions + raise Exception() + + return file diff --git a/snet/sdk/service_client.py b/snet/sdk/service_client.py index 0ffcb93..c0faa87 100644 --- a/snet/sdk/service_client.py +++ b/snet/sdk/service_client.py @@ -21,7 +21,7 @@ PrePaidPaymentStrategy, ) from snet.sdk.resources.root_certificate import certificate -from snet.sdk.storage_provider.service_metadata import MPEServiceMetadata +from snet.sdk.registry.service_metadata import MPEServiceMetadata from snet.sdk.custom_typing import ModuleName, ServiceStub from snet.sdk.utils.utils import ( RESOURCES_PATH, @@ -59,6 +59,7 @@ def __init__( if isinstance(payment_strategy, PrePaidPaymentStrategy): self.payment_strategy.set_concurrent_calls(options["concurrent_calls"]) self.options = options + self.mpe_contract = mpe_contract self.mpe_address = mpe_contract.contract.address self.account = account self.sdk_web3 = sdk_web3 @@ -193,6 +194,9 @@ def _generate_payment_channel_state_service_client(self) -> Any: state_service = importlib.import_module("state_service_pb2_grpc") return state_service.PaymentChannelStateServiceStub(grpc_channel) + def get_mpe_balance(self): + return self.mpe_contract.balance(self.account) + def open_channel(self, amount: int, expiration: int) -> PaymentChannel: payment_address = self.group["payment"]["payment_address"] group_id = base64.b64decode(str(self.group["group_id"])) diff --git a/snet/sdk/storage_provider/storage_provider.py b/snet/sdk/storage_provider/storage_provider.py deleted file mode 100644 index edfdec1..0000000 --- a/snet/sdk/storage_provider/storage_provider.py +++ /dev/null @@ -1,91 +0,0 @@ -import web3 -from lighthouseweb3 import Lighthouse -import json - -from snet.sdk.utils.ipfs_utils import ( - get_ipfs_client, - get_from_ipfs_and_checkhash, -) -from snet.sdk.utils.utils import bytesuri_to_hash, safe_extract_proto -from snet.sdk.storage_provider.service_metadata import ( - MPEServiceMetadata, - mpe_service_metadata_from_json, -) -from snet.sdk.config import config - - -class StorageProvider(object): - def __init__(self, registry_contract): - self._registry_contract = registry_contract - self._ipfs_client = get_ipfs_client() - self.lighthouse_client = Lighthouse(config.LIGHTHOUSE_TOKEN) - - def fetch_org_metadata(self, org_id): - org = web3.Web3.to_bytes(text=org_id).ljust(32, b"\0") - - found, _, org_metadata_uri, _, _, _ = self._registry_contract.functions.getOrganizationById( - org - ).call() - if found is not True: - raise Exception('Organization with org ID "{}" not found '.format(org_id)) - - org_provider_type, org_metadata_hash = bytesuri_to_hash(org_metadata_uri) - - if org_provider_type == "ipfs": - org_metadata_json = get_from_ipfs_and_checkhash(self._ipfs_client, org_metadata_hash) - else: - org_metadata_json, _ = self.lighthouse_client.download(org_metadata_hash) - org_metadata = json.loads(org_metadata_json) - - return org_metadata - - def fetch_service_metadata(self, org_id: str, service_id: str) -> MPEServiceMetadata: - org = web3.Web3.to_bytes(text=org_id).ljust(32, b"\0") - service = web3.Web3.to_bytes(text=service_id).ljust(32, b"\0") - - found, _, service_metadata_uri = ( - self._registry_contract.functions.getServiceRegistrationById(org, service).call() - ) - if found is not True: - raise Exception(f"No service '{service_id}' found in organization '{org_id}'") - - service_provider_type, service_metadata_hash = bytesuri_to_hash(s=service_metadata_uri) - - if service_provider_type == "ipfs": - service_metadata_json = get_from_ipfs_and_checkhash( - self._ipfs_client, service_metadata_hash - ) - else: - service_metadata_json, _ = self.lighthouse_client.download(cid=service_metadata_hash) - service_metadata = mpe_service_metadata_from_json(service_metadata_json) - - return service_metadata - - def enhance_service_metadata(self, org_id, service_id): - service_metadata = self.fetch_service_metadata(org_id, service_id) - org_metadata = self.fetch_org_metadata(org_id) - - org_group_map = {} - for group in org_metadata["groups"]: - org_group_map[group["group_name"]] = group - - for group in service_metadata.m["groups"]: - # merge service group with org_group - group["payment"] = org_group_map[group["group_name"]]["payment"] - - return service_metadata - - def fetch_and_extract_proto(self, service_api_source, protodir): - try: - proto_provider_type, service_api_source = bytesuri_to_hash( - service_api_source, to_decode=False - ) - except Exception: - proto_provider_type = "ipfs" - - if proto_provider_type == "ipfs": - spec_tar = get_from_ipfs_and_checkhash(self._ipfs_client, service_api_source) - else: - spec_tar, _ = self.lighthouse_client.download(service_api_source) - - safe_extract_proto(spec_tar, protodir) diff --git a/snet/sdk/types.py b/snet/sdk/types.py new file mode 100644 index 0000000..216ed18 --- /dev/null +++ b/snet/sdk/types.py @@ -0,0 +1,44 @@ +from dataclasses import dataclass +from enum import Enum + + +@dataclass +class RawOrgData: + org_id: bytes + metadata_uri: bytes + owner: str + members: list[str] + services: list[bytes] # IDs + + +@dataclass +class RawServiceData: + service_id: bytes + metadata_uri: bytes + + +class StorageType(Enum): + IPFS = "ipfs" + FILECOIN = "filecoin" + + +@dataclass +class FileURI: + storage_type: StorageType + uri_hash: str + + +@dataclass +class OrgData: + org_id: str + metadata_uri: FileURI + owner: str + members: list[str] + services: list[str] # IDs + + +@dataclass +class ServiceData: + org_id: str + service_id: str + metadata_uri: FileURI diff --git a/snet/sdk/utils/utils.py b/snet/sdk/utils/utils.py index 0f2d15b..9e4c876 100644 --- a/snet/sdk/utils/utils.py +++ b/snet/sdk/utils/utils.py @@ -1,6 +1,8 @@ import json import sys import importlib.resources +from functools import lru_cache +from typing import Optional, Union from urllib.parse import urlparse from pathlib import Path, PurePath import os @@ -10,19 +12,50 @@ import web3 from eth_typing import BlockNumber from grpc_tools.protoc import main as protoc +from web3 import Web3 from snet import sdk +from snet.sdk.config import config +from snet.sdk.types import StorageType, FileURI, RawOrgData, OrgData, RawServiceData, ServiceData RESOURCES_PATH = PurePath(os.path.dirname(sdk.__file__)).joinpath("resources") +def convert_raw_org_data(raw_org_data: RawOrgData) -> OrgData: + return OrgData( + org_id=bytes32_to_str(raw_org_data.org_id), + metadata_uri=bytesuri_to_hash(raw_org_data.metadata_uri), + owner=raw_org_data.owner, + members=raw_org_data.members, + services=list(map(bytes32_to_str, raw_org_data.services)), + ) + + +def convert_raw_service_data( + raw_service_data: RawServiceData, org_id: Union[str, bytes] +) -> ServiceData: + return ServiceData( + org_id=bytes32_to_str(org_id) if isinstance(org_id, bytes) else org_id, + service_id=bytes32_to_str(raw_service_data.service_id), + metadata_uri=bytesuri_to_hash(raw_service_data.metadata_uri), + ) + + +@lru_cache +def get_we3_object(eth_rpc_endpoint: Optional[str] = None) -> Web3: + if eth_rpc_endpoint is None: + eth_rpc_endpoint = config.ETH_RPC_ENDPOINT + + return web3.Web3(web3.HTTPProvider(eth_rpc_endpoint)) + + def safe_address_converter(a): if not web3.Web3.is_checksum_address(a): raise Exception("%s is not is not a valid Ethereum checksum address" % a) return a -def type_converter(t): +def type_converter(t: str): if t.endswith("[]"): return lambda x: list(map(type_converter(t.replace("[]", "")), json.loads(x))) else: @@ -137,7 +170,7 @@ def get_address_from_private(private_key): def get_current_block_number() -> BlockNumber: - return web3.Web3().eth.block_number + return get_we3_object().eth.block_number class add_to_path: @@ -166,11 +199,11 @@ def find_file_by_keyword(directory, keyword, exclude=None): def bytesuri_to_hash(s, to_decode=True): if to_decode: s = s.rstrip(b"\0").decode("ascii") - if s.startswith("ipfs://"): - return "ipfs", s[7:] - elif s.startswith("filecoin://"): - return "filecoin", s[11:] - else: + try: + storage_type, storage_hash = s.split("://") + return FileURI(StorageType(storage_type), storage_hash) + except ValueError: + # TODO: configure exceptions raise Exception("We support only ipfs and filecoin uri in Registry") diff --git a/tests/unit_tests/test_lib_generator.py b/tests/unit_tests/test_lib_generator.py index ec0aaeb..1bcfac0 100644 --- a/tests/unit_tests/test_lib_generator.py +++ b/tests/unit_tests/test_lib_generator.py @@ -4,7 +4,7 @@ from unittest.mock import Mock, patch from snet.sdk.client_lib_generator import ClientLibGenerator -from snet.sdk.storage_provider.storage_provider import StorageProvider +from snet.sdk.registry.storage_provider import StorageProvider class TestClientLibGenerator(unittest.TestCase): @@ -18,29 +18,29 @@ def setUp(self): metadata_provider=self.mock_metadata_provider, org_id=self.org_id, service_id=self.service_id, - protodir=self.protodir, + proto_dir=self.protodir, ) @patch("pathlib.Path.mkdir") def test_generate_directories_by_params_by_absolute_path(self, mock_mkdir): expected_library_dir = self.protodir.joinpath(self.org_id, self.service_id, self.language) self.generator.generate_directories_by_params() - self.assertEqual(self.generator.protodir, expected_library_dir) + self.assertEqual(self.generator.proto_dir, expected_library_dir) mock_mkdir.assert_called_once_with(parents=True, exist_ok=True) @patch("pathlib.Path.mkdir") def test_generate_directories_by_params_by_relative_path(self, mock_mkdir): - self.generator.protodir = Path(".snet_test") + self.generator.proto_dir = Path(".snet_test") expected_library_dir = Path.cwd().joinpath( - self.generator.protodir, self.org_id, self.service_id, self.language + self.generator.proto_dir, self.org_id, self.service_id, self.language ) self.generator.generate_directories_by_params() - self.assertEqual(self.generator.protodir, expected_library_dir) + self.assertEqual(self.generator.proto_dir, expected_library_dir) mock_mkdir.assert_called_once_with(parents=True, exist_ok=True) def test_create_service_client_libraries_path(self): mock_protodir = Mock(spec=Path) - self.generator.protodir = mock_protodir + self.generator.proto_dir = mock_protodir mock_library_path = Mock(spec=Path) mock_protodir.joinpath.return_value = mock_library_path @@ -52,7 +52,7 @@ def test_create_service_client_libraries_path(self): mock_library_path.mkdir.assert_called_once_with(parents=True, exist_ok=True) # Assert that the protodir is updated correctly - self.assertEqual(self.generator.protodir, mock_library_path) + self.assertEqual(self.generator.proto_dir, mock_library_path) def test_receive_proto_files_success(self): # Set up mocks @@ -61,8 +61,8 @@ def test_receive_proto_files_success(self): "model_ipfs_hash": os.getenv("MODEL_IPFS_HASH"), } self.mock_metadata_provider.fetch_service_metadata.return_value = mock_metadata - self.generator.protodir = Mock() - self.generator.protodir.exists.return_value = True + self.generator.proto_dir = Mock() + self.generator.proto_dir.exists.return_value = True # Call the method self.generator.receive_proto_files() @@ -75,12 +75,12 @@ def test_receive_proto_files_success(self): org_id=self.org_id, service_id=self.service_id ) self.mock_metadata_provider.fetch_and_extract_proto.assert_called_once_with( - service_api_source, self.generator.protodir + service_api_source, self.generator.proto_dir ) def test_receive_proto_files_failed(self): - self.generator.protodir = Mock() - self.generator.protodir.exists.return_value = False + self.generator.proto_dir = Mock() + self.generator.proto_dir.exists.return_value = False with self.assertRaises(Exception) as context: self.generator.receive_proto_files() diff --git a/tests/unit_tests/test_service_client.py b/tests/unit_tests/test_service_client.py index ee53738..f2c8e1d 100644 --- a/tests/unit_tests/test_service_client.py +++ b/tests/unit_tests/test_service_client.py @@ -8,7 +8,7 @@ from snet.sdk.mpe.mpe_contract import MPEContract from snet.sdk.mpe.payment_channel_provider import PaymentChannelProvider from snet.sdk.service_client import ServiceClient -from snet.sdk.storage_provider.service_metadata import MPEServiceMetadata +from snet.sdk.registry.service_metadata import MPEServiceMetadata class TestServiceClient(unittest.TestCase):