diff --git a/tests/test_tos.py b/tests/test_tos.py index a680d7cc..6490d89d 100644 --- a/tests/test_tos.py +++ b/tests/test_tos.py @@ -67,7 +67,9 @@ def __init__( type("S", (), {"Storage_Class_Standard": "STANDARD"}), ) monkeypatch.setattr( - tos_mod.tos, "ACLType", type("A", (), {"ACL_Private": "private"}) + tos_mod.tos, + "ACLType", + type("A", (), {"ACL_Private": "private", "ACL_Public_Read": "public-read"}), ) return fake_client diff --git a/veadk/integrations/ve_tos/ve_tos.py b/veadk/integrations/ve_tos/ve_tos.py index 9d49f6c8..89162d16 100644 --- a/veadk/integrations/ve_tos/ve_tos.py +++ b/veadk/integrations/ve_tos/ve_tos.py @@ -72,29 +72,74 @@ def model_post_init(self, __context: Any) -> None: logger.error(f"Client initialization failed:{e}") self._client = None + def _refresh_client(self): + try: + if self._client: + self._client.close() + self._client = tos.TosClientV2( + self.config.ak, + self.config.sk, + endpoint=f"tos-{self.config.region}.volces.com", + region=self.config.region, + ) + logger.info("refreshed client successfully.") + except Exception as e: + logger.error(f"Failed to refresh client: {str(e)}") + self._client = None + def create_bucket(self) -> bool: - """If the bucket does not exist, create it""" + """If the bucket does not exist, create it and set CORS rules""" if not self._client: logger.error("TOS client is not initialized") return False try: self._client.head_bucket(self.config.bucket_name) logger.info(f"Bucket {self.config.bucket_name} already exists") - return True except tos.exceptions.TosServerError as e: if e.status_code == 404: - self._client.create_bucket( - bucket=self.config.bucket_name, - storage_class=tos.StorageClassType.Storage_Class_Standard, - acl=tos.ACLType.ACL_Private, - ) - logger.info(f"Bucket {self.config.bucket_name} created successfully") - return True + try: + self._client.create_bucket( + bucket=self.config.bucket_name, + storage_class=tos.StorageClassType.Storage_Class_Standard, + acl=tos.ACLType.ACL_Public_Read, # 公开读 + ) + logger.info( + f"Bucket {self.config.bucket_name} created successfully" + ) + self._refresh_client() + except Exception as create_error: + logger.error(f"Bucket creation failed: {str(create_error)}") + return False else: - logger.error(f"Bucket creation failed: {str(e)}") + logger.error(f"Bucket check failed: {str(e)}") return False except Exception as e: - logger.error(f"Bucket creation failed: {str(e)}") + logger.error(f"Bucket check failed: {str(e)}") + return False + + # 确保在所有路径上返回布尔值 + return self._set_cors_rules() + + def _set_cors_rules(self) -> bool: + if not self._client: + logger.error("TOS client is not initialized") + return False + try: + rule = tos.models2.CORSRule( + allowed_origins=["*"], + allowed_methods=["GET", "HEAD"], + allowed_headers=["*"], + max_age_seconds=1000, + ) + self._client.put_bucket_cors(self.config.bucket_name, [rule]) + logger.info( + f"CORS rules for bucket {self.config.bucket_name} set successfully" + ) + return True + except Exception as e: + logger.error( + f"Failed to set CORS rules for bucket {self.config.bucket_name}: {str(e)}" + ) return False def build_tos_url( @@ -153,7 +198,6 @@ def _do_upload_file(self, object_key: str, file_path: str) -> None: return if not self.create_bucket(): return - self._client.put_object_from_file( bucket=self.config.bucket_name, key=object_key, file_path=file_path )